krilecy committed
Commit a2cbf75
1 Parent(s): 22ad347

Upload handler.py

Files changed (1)
  1. handler.py +14 -120
handler.py CHANGED
@@ -12,111 +12,10 @@ from torch import Tensor
 import torch.nn.functional as F
 from torch import nn
 
-# Override standard attention in Mistral to get extractable attention
 from transformers.models.mistral.modeling_mistral import MistralAttention
+from ExtractableMistralAttention import forward
-from typing import Optional, Tuple
-import warnings
-import math
-
-from transformers.cache_utils import Cache
-from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
-
-
-class MistralAttention(MistralAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            print('kv ', type(kv_seq_len))
-            print('past ', type(past_key_value.get_usable_length(kv_seq_len, self.layer_idx)))
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            print('attention_weights ', type(attn_weights))
-            print('attention_mask ', type(attention_mask))
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        d = value_states.size(-1)  # Custom
-
-        averaging_matrix = torch.full(value_states.shape, 1/d)  # Custom
-        custom_attn_output = torch.matmul(attn_weights.to('cpu'), averaging_matrix)  # Custom
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-        custom_attn_output = custom_attn_output.transpose(1, 2).contiguous()  # Custom
-        custom_attn_output = custom_attn_output.reshape(bsz, q_len, self.hidden_size)  # Custom
-        self.attn_vec = custom_attn_output  # Custom
-
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
 
+MistralAttention.forward = forward
 
 class EndpointHandler():
     def __init__(self, model_dir=''):
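The override itself is not gone: the new version imports a forward function from a separate ExtractableMistralAttention module and patches it onto the class (the two + lines above). A minimal sketch of what ExtractableMistralAttention.py presumably contains, assuming the removed forward() above was moved there unchanged as a module-level function; only the module name is confirmed by the new import, the rest is inferred:

# ExtractableMistralAttention.py -- sketch only, not the actual uploaded file.
import math
import warnings
from typing import Optional, Tuple

import torch
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv


def forward(self, hidden_states, attention_mask=None, position_ids=None,
            past_key_value=None, output_attentions=False, use_cache=False, **kwargs):
    # Body identical to the removed MistralAttention.forward above, including the
    # "# Custom" lines that compute custom_attn_output from the attention weights
    # and stash it on the module:
    #     self.attn_vec = custom_attn_output
    ...

Because the patch MistralAttention.forward = forward is applied before the model is instantiated, every attention layer stores its extractable vector in self.attn_vec, which the handler reads back from the last layer.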
@@ -142,7 +41,7 @@ class EndpointHandler():
         self.model.eval()
 
 
-    def last_token_pool(last_hidden_states: Tensor,
+    def last_token_pool(self, last_hidden_states: Tensor,
                         attention_mask: Tensor) -> Tensor:
         left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
         if left_padding:
@@ -153,20 +52,18 @@ class EndpointHandler():
         return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
 
 
-    def tokenize(self, text, type):
-        if type == 'query':
-            print('inst', type(self.instruction))
-            print('inst', type(text))
+    def tokenize(self, text, request_type):
+        if request_type == 'query':
             text = self.instruction + text
         return self.tokenizer(text + self.tokenizer.eos_token, max_length=self.max_length, truncation=True, return_tensors='pt').to(self.device)
 
 
-    def extract_attn_vec(model):
-        return model._modules['layers'][-1].self_attn.attn_vec
+    def extract_attn_vec(self):
+        return self.model._modules['layers'][-1].self_attn.attn_vec
 
 
-    def embed(self, text, type):
-        tokens = self.tokenize(text, type)
+    def embed(self, text, request_type):
+        tokens = self.tokenize(text, request_type)
         with torch.no_grad():
             output = self.model(tokens['input_ids'], tokens['attention_mask']).last_hidden_state.detach()
             embedding = self.last_token_pool(output, tokens['attention_mask'])
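The pooling these methods rely on selects the hidden state at the last non-padding token of each sequence. A small self-contained sketch of the last_token_pool logic with toy tensors; the lines between if left_padding: and the final return are not shown in the diff, so the else branch here is filled in with the standard last-token-pooling formulation and should be read as an assumption:

# Toy illustration of last-token pooling; tensor values are made up.
import torch

last_hidden_states = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)  # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0],   # right-padded: last real token at index 2
                               [1, 1, 1, 1]])  # unpadded: last token at index 3

left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
    # With left padding the final position is always a real token.
    pooled = last_hidden_states[:, -1]
else:
    # Otherwise index each sequence at its own last non-padding position.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    pooled = last_hidden_states[torch.arange(batch_size), sequence_lengths]

print(pooled.shape)  # torch.Size([2, 3]) -- one pooled vector per sequence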
@@ -188,20 +85,17 @@ class EndpointHandler():
         Return:
             A :obj:`list` | `dict`: will be serialized and returned
         """
-        print(data)
         id = data.pop("id", data)
         text = data.pop("text", data)
-        type = data.pop("type", data)
-        print(id)
-        print(text)
-        print(type)
+        request_type = data.pop("type", data)
+
 
-        embeddings, attn_vec = self.embed(text, type)
+        embeddings, attn_vec = self.embed(text, request_type)
         embeddings = embeddings[0].tolist()
         attn_vec = attn_vec[0].tolist()
 
-        if type == 'query':
+        if request_type == 'query':
             return {"id": id, "embedding": embeddings, "attention_vec": attn_vec}
 
-        elif type == 'document':
+        elif request_type == 'document':
             return {"id": id, "embedding": embeddings}
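For reference, a sketch of how the updated handler would be invoked. The payload keys (id, text, type) and the returned fields come from the diff; the __call__(self, data) entry point is the usual convention for a custom Hugging Face Inference Endpoints handler and is assumed here, and the model directory and texts are placeholders:

# Hypothetical local invocation of the handler; field names taken from the diff above.
from handler import EndpointHandler

handler = EndpointHandler(model_dir='')

# A query gets the instruction prefix at tokenization time and also returns the attention vector.
query_out = handler({"id": "q-1", "text": "How is the attention vector extracted?", "type": "query"})
# -> {"id": "q-1", "embedding": [...], "attention_vec": [...]}

# A document returns only the embedding.
doc_out = handler({"id": "d-1", "text": "Some passage to be indexed.", "type": "document"})
# -> {"id": "d-1", "embedding": [...]}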
 