krilecy committed on
Commit dfd303b
1 Parent(s): 9d47866

Upload handler.py

Files changed (1)
  1. handler.py +103 -3
handler.py CHANGED
@@ -5,14 +5,114 @@ from typing import Dict, Any
  from peft import PeftModel
  from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
  import transformers
- from transformers.models.mistral.modeling_mistral import MistralAttention
- from ExtractableMistralAttention import forward

- MistralAttention.forward = forward

  import torch
  from torch import Tensor
  import torch.nn.functional as F
+ from torch import nn
+
+ # Override standard attention in Mistral to get extractable attention
+ from transformers.models.mistral.modeling_mistral import MistralAttention
+ from typing import Optional, Tuple
+ import warnings
+ import math
+
+ from transformers.cache_utils import Cache
+ from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
+
+
+ class MistralAttention(MistralAttention):
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Cache] = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         if "padding_mask" in kwargs:
+             warnings.warn(
+                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+             )
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+         kv_seq_len = key_states.shape[-2]
+         if past_key_value is not None:
+             if self.layer_idx is None:
+                 raise ValueError(
+                     f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                     "with a layer index."
+                 )
+             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+         if past_key_value is not None:
+             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+         # repeat k/v heads if n_kv_heads < n_heads
+         key_states = repeat_kv(key_states, self.num_key_value_groups)
+         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+             raise ValueError(
+                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                 f" {attn_weights.size()}"
+             )
+
+         if attention_mask is not None:
+             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                 raise ValueError(
+                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                 )
+
+             attn_weights = attn_weights + attention_mask
+
+         # upcast attention to fp32
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+         attn_output = torch.matmul(attn_weights, value_states)
+
+         d = value_states.size(-1)  # Custom
+
+         averaging_matrix = torch.full(value_states.shape, 1/d)  # Custom
+         custom_attn_output = torch.matmul(attn_weights.to('cpu'), averaging_matrix)  # Custom
+
+         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+             raise ValueError(
+                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                 f" {attn_output.size()}"
+             )
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+         custom_attn_output = custom_attn_output.transpose(1, 2).contiguous()  # Custom
+         custom_attn_output = custom_attn_output.reshape(bsz, q_len, self.hidden_size)  # Custom
+         self.attn_vec = custom_attn_output  # Custom
+
+         attn_output = self.o_proj(attn_output)
+
+         if not output_attentions:
+             attn_weights = None
+
+         return attn_output, attn_weights, past_key_value
+


  class EndpointHandler():
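
For reference, a minimal sketch (not part of this commit) of how the `attn_vec` tensors stored by the overridden attention could be read back after a forward pass. The function name and calling convention are illustrative assumptions; it only presumes that the loaded model's attention modules actually use the overridden class and therefore expose an `attn_vec` attribute.

# Illustrative only: gather per-layer attention vectors after one forward pass,
# assuming the overridden MistralAttention is the class used by the model's layers.
import torch

def collect_attn_vecs(model, inputs):
    with torch.no_grad():
        model(**inputs)
    # Each patched attention module keeps its averaged output in `attn_vec`
    # (batch x seq_len x hidden_size, moved to CPU inside forward).
    return {
        name: module.attn_vec
        for name, module in model.named_modules()
        if hasattr(module, "attn_vec")
    }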