Upload handler.py
handler.py  CHANGED  (+14 -120)
@@ -12,111 +12,10 @@ from torch import Tensor
 import torch.nn.functional as F
 from torch import nn
 
-# Override standard attention in Mistral to get extractable attention
 from transformers.models.mistral.modeling_mistral import MistralAttention
-from
-import warnings
-import math
-
-from transformers.cache_utils import Cache
-from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
-
-
-class MistralAttention(MistralAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            print('kv ', type(kv_seq_len))
-            print('past ', type(past_key_value.get_usable_length(kv_seq_len, self.layer_idx)))
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            print('attention_weights ', type(attn_weights))
-            print('attention_mask ', type(attention_mask))
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        d = value_states.size(-1)  # Custom
-
-        averaging_matrix = torch.full(value_states.shape, 1/d)  # Custom
-        custom_attn_output = torch.matmul(attn_weights.to('cpu'), averaging_matrix)  # Custom
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-        custom_attn_output = custom_attn_output.transpose(1, 2).contiguous()  # Custom
-        custom_attn_output = custom_attn_output.reshape(bsz, q_len, self.hidden_size)  # Custom
-        self.attn_vec = custom_attn_output  # Custom
-
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
+from ExtractableMistralAttention import forward
 
+MistralAttention.forward = forward
 
 class EndpointHandler():
     def __init__(self, model_dir=''):
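The custom forward now lives in a separate ExtractableMistralAttention module (not shown in this commit), and the assignment above patches it onto every MistralAttention layer of the loaded model. A generic sketch of this method-patching pattern, using placeholder names rather than the real module:

class Attention:
    def forward(self, x):
        return x

def extractable_forward(self, x):
    # stash an intermediate value on the module, as the patched forward does with attn_vec
    self.attn_vec = x
    return x

Attention.forward = extractable_forward  # every Attention instance now runs the patched method

layer = Attention()
layer.forward(3)
print(layer.attn_vec)  # 3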
@@ -142,7 +41,7 @@ class EndpointHandler():
         self.model.eval()
 
 
-    def last_token_pool(last_hidden_states: Tensor,
+    def last_token_pool(self, last_hidden_states: Tensor,
                         attention_mask: Tensor) -> Tensor:
         left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
         if left_padding:
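Only the signature of last_token_pool changes here (it becomes an instance method); the pooling logic is untouched. For reference, a standalone sketch of the same pooling scheme, reconstructed from the visible lines rather than copied from the file: with left padding every sequence ends at the final position, otherwise each sequence is indexed at its last non-padded token.

import torch

def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # last_hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

hidden = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
pooled = last_token_pool(hidden, mask)  # shape (2, 8): rows taken at positions 2 and 4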
@@ -153,20 +52,18 @@ class EndpointHandler():
         return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
 
 
-    def tokenize(self, text,
-        if
-            print('inst', type(self.instruction))
-            print('inst', type(text))
+    def tokenize(self, text, request_type):
+        if request_type == 'query':
            text = self.instruction + text
        return self.tokenizer(text + self.tokenizer.eos_token, max_length=self.max_length, truncation=True, return_tensors='pt').to(self.device)
 
 
-    def extract_attn_vec(
-        return model._modules['layers'][-1].self_attn.attn_vec
+    def extract_attn_vec(self):
+        return self.model._modules['layers'][-1].self_attn.attn_vec
 
 
-    def embed(self, text,
-        tokens = self.tokenize(text,
+    def embed(self, text, request_type):
+        tokens = self.tokenize(text, request_type)
         with torch.no_grad():
             output = self.model(tokens['input_ids'], tokens['attention_mask']).last_hidden_state.detach()
             embedding = self.last_token_pool(output, tokens['attention_mask'])
@@ -188,20 +85,17 @@ class EndpointHandler():
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
-        print(data)
        id = data.pop("id", data)
        text = data.pop("text", data)
-
-
-        print(text)
-        print(type)
+        request_type = data.pop("type", data)
+
 
-        embeddings, attn_vec = self.embed(text,
+        embeddings, attn_vec = self.embed(text, request_type)
        embeddings = embeddings[0].tolist()
        attn_vec = attn_vec[0].tolist()
 
-        if
+        if request_type == 'query':
            return {"id": id, "embedding": embeddings, "attention_vec": attn_vec}
 
-        elif
+        elif request_type == 'document':
            return {"id": id, "embedding": embeddings}
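After this change the request payload must carry a "type" field: queries get the instruction prefix in tokenize and receive the attention vector in the response, while documents get only the embedding. A minimal sketch of how the updated handler might be invoked, assuming the usual Hugging Face inference-endpoint __call__(data) entry point; the model directory and texts are placeholders, not part of the commit:

from handler import EndpointHandler

handler = EndpointHandler(model_dir="./model")  # placeholder path

query = {"id": 1, "text": "How does rotary position embedding work?", "type": "query"}
document = {"id": 2, "text": "RoPE rotates query and key vectors by position-dependent angles.", "type": "document"}

print(handler(query))     # {"id": 1, "embedding": [...], "attention_vec": [...]}
print(handler(document))  # {"id": 2, "embedding": [...]}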