File size: 8,327 Bytes
8753b76 dfd303b 8753b76 6891925 8753b76 6891925 8753b76 6891925 8753b76 6891925 8753b76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
# handler file for Huggingface Inference API
from typing import Dict, Any
from peft import PeftModel
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
import transformers
import torch
from torch import Tensor
import torch.nn.functional as F
from torch import nn
# Override standard attention in Mistral to get extractable attention
from transformers.models.mistral.modeling_mistral import MistralAttention
from typing import Optional, Tuple
import warnings
import math
from transformers.cache_utils import Cache
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
class MistralAttention(MistralAttention):
    """Eager Mistral attention that also records an "attention vector".

    The forward pass computes standard scaled dot-product attention with
    rotary position embeddings and grouped key/value heads. In addition, it
    multiplies the attention weights by a constant matrix filled with
    ``1 / head_dim`` and stores the reshaped result on ``self.attn_vec`` so a
    caller can read it back after the forward pass.

    NOTE(review): nothing in the visible part of this file installs this
    subclass into the transformers model-building code; confirm that the
    loaded model actually uses it, otherwise ``attn_vec`` is never set.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run attention and record the averaged attention vector.

        Args:
            hidden_states: Input of shape ``(bsz, q_len, hidden)``.
            attention_mask: Optional additive mask; must have shape
                ``(bsz, 1, q_len, kv_seq_len)`` when given.
            position_ids: Positions passed to the rotary embedding.
            past_key_value: Optional key/value cache, updated in place.
            output_attentions: When false, ``None`` is returned in place of
                the attention weights.
            use_cache: Accepted for interface compatibility; not read here.
            **kwargs: Only ``padding_mask`` is inspected (deprecation warning).

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``.

        Raises:
            ValueError: If a cache is used without ``layer_idx``, or if the
                attention weights, mask or output have an unexpected shape.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Split the projections into heads: (bsz, heads, q_len, head_dim).
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        # Custom: multiply the attention weights by a constant 1/d matrix on
        # the CPU. torch.matmul does not promote mixed dtypes, so both operands
        # are cast to float32 explicitly; otherwise half-precision weights
        # would clash with the float32 constant matrix.
        # NOTE(review): every output element is the sum of that query's
        # attention weights divided by d, i.e. roughly the constant 1/d after
        # softmax -- confirm this is the intended quantity.
        d = value_states.size(-1)
        averaging_matrix = torch.full(value_states.shape, 1 / d, dtype=torch.float32, device="cpu")
        cpu_attn_weights = attn_weights.to(device="cpu", dtype=torch.float32)
        custom_attn_output = torch.matmul(cpu_attn_weights, averaging_matrix)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        # Custom: reshape like the real output and keep it on the module so
        # it can be read back after the forward pass.
        custom_attn_output = custom_attn_output.transpose(1, 2).contiguous()
        custom_attn_output = custom_attn_output.reshape(bsz, q_len, self.hidden_size)
        self.attn_vec = custom_attn_output

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
class EndpointHandler():
    """Inference handler that embeds a text and returns its attention vector.

    On construction it loads a tokenizer and an 8-bit quantized model from
    ``model_dir`` and applies a LoRA adapter read from ``/lora``. Calling the
    handler with a request dict returns the L2-normalized last-token embedding
    and, for queries, the L2-normalized attention vector.
    """

    def __init__(self, model_dir=''):
        """Load the tokenizer, the quantized base model and the LoRA adapter.

        Args:
            model_dir: Directory (or identifier) passed to ``from_pretrained``
                for both the tokenizer and the model.
        """
        # Prefix prepended to every query text before tokenization.
        self.instruction = 'Given a web search query, retrieve relevant passages that answer the query:\n'
        self.max_length = 4096
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        self.tokenizer.pad_token = '[PAD]'
        self.tokenizer.padding_side = 'left'
        bnb_config = BitsAndBytesConfig(load_in_8bit=True, bnb_8bit_compute_dtype=torch.float16)
        self.model = AutoModel.from_pretrained(
            model_dir,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="eager",
        )
        # NOTE(review): the adapter path is hard-coded; confirm '/lora' exists
        # in the deployment image.
        self.model = PeftModel.from_pretrained(self.model, '/lora')
        self.model.eval()

    @staticmethod
    def last_token_pool(last_hidden_states: Tensor,
                        attention_mask: Tensor) -> Tensor:
        """Select the hidden state of the last real token of each sequence.

        Declared static because it takes no ``self`` yet is invoked through
        the instance.

        Args:
            last_hidden_states: Tensor indexed as ``[batch, position, ...]``.
            attention_mask: Tensor indexed as ``[batch, position]`` with 1 for
                real tokens.

        Returns:
            One vector per batch row.
        """
        # When every row's final position is unmasked, the batch is treated as
        # left-padded and the last position is the last real token.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        # The index tensor must live on the same device as the tensor being
        # indexed; the mask may be on a different device than the states.
        sequence_lengths = (attention_mask.sum(dim=1) - 1).to(last_hidden_states.device)
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def tokenize(self, text, type):
        """Tokenize ``text`` (with EOS appended) and move it to the device.

        Args:
            text: Input string.
            type: ``'query'`` prepends the retrieval instruction; any other
                value leaves the text unchanged.

        Returns:
            The tokenizer output as PyTorch tensors on ``self.device``.
        """
        if type == 'query':
            text = self.instruction + text
        return self.tokenizer(
            text + self.tokenizer.eos_token,
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
        ).to(self.device)

    @staticmethod
    def extract_attn_vec(model):
        """Return the ``attn_vec`` recorded by the last attention module.

        The module tree is scanned rather than indexing a top-level ``layers``
        entry, so the lookup also works when the model is wrapped (the handler
        wraps it in a ``PeftModel``).

        Args:
            model: A ``torch.nn.Module`` whose attention modules store an
                ``attn_vec`` attribute during the forward pass.

        Returns:
            The ``attn_vec`` of the last module (in registration order) that
            has one.

        Raises:
            AttributeError: If no module has recorded an ``attn_vec``.
        """
        last_vec = None
        for module in model.modules():
            vec = getattr(module, 'attn_vec', None)
            if vec is not None:
                last_vec = vec
        if last_vec is None:
            raise AttributeError(
                "No module in the model has recorded `attn_vec`; make sure the "
                "custom attention forward is in use and a forward pass has run."
            )
        return last_vec

    def embed(self, text, type):
        """Embed ``text`` and pool the recorded attention vector.

        Args:
            text: Input string.
            type: Request type forwarded to :meth:`tokenize`.

        Returns:
            ``(embedding, attn_vec)``, both pooled at the last token and
            L2-normalized along dimension 1.
        """
        tokens = self.tokenize(text, type)
        with torch.no_grad():
            output = self.model(tokens['input_ids'], tokens['attention_mask']).last_hidden_state.detach()
            embedding = self.last_token_pool(output, tokens['attention_mask'])
            embedding = F.normalize(embedding, p=2, dim=1)
            attn_vec = self.extract_attn_vec(self.model)
            attn_vec = self.last_token_pool(attn_vec, tokens['attention_mask'])
            attn_vec = F.normalize(attn_vec, p=2, dim=1)
            # Drop the large intermediates before returning to limit memory use.
            del output, tokens
            torch.cuda.empty_cache()
        return embedding, attn_vec

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request.

        Args:
            data: Request payload with keys ``"id"`` (echoed back, optional),
                ``"text"`` (string to embed) and ``"type"`` (``'query'`` or
                ``'document'``). The dict is not modified.

        Returns:
            ``{"id", "embedding"}`` for documents, plus ``"attention_vec"``
            for queries.

        Raises:
            ValueError: If ``"text"`` is not a string or ``"type"`` is not one
                of the supported values.
        """
        request_id = data.get("id")
        text = data.get("text")
        request_type = data.get("type")

        # Validate before running the model so bad requests fail fast with a
        # clear message instead of an unrelated error or a silent ``None``.
        if not isinstance(text, str):
            raise ValueError('Request must contain a string "text" field.')
        if request_type not in ('query', 'document'):
            raise ValueError('Request "type" must be \'query\' or \'document\'.')

        embeddings, attn_vec = self.embed(text, request_type)
        response = {"id": request_id, "embedding": embeddings[0].tolist()}
        if request_type == 'query':
            response["attention_vec"] = attn_vec[0].tolist()
        return response