krilecy commited on
Commit
9d47866
1 Parent(s): 8753b76

Upload ExtractableMistralAttention.py

Browse files
Files changed (1) hide show
  1. ExtractableMistralAttention.py +99 -0
ExtractableMistralAttention.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from typing import Optional, Tuple
5
+ import warnings
6
+ import math
7
+
8
+ from transformers.cache_utils import Cache
9
+ from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
10
+
11
+ def forward(
12
+ self,
13
+ hidden_states: torch.Tensor,
14
+ attention_mask: Optional[torch.Tensor] = None,
15
+ position_ids: Optional[torch.LongTensor] = None,
16
+ past_key_value: Optional[Cache] = None,
17
+ output_attentions: bool = False,
18
+ use_cache: bool = False,
19
+ **kwargs,
20
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
21
+ if "padding_mask" in kwargs:
22
+ warnings.warn(
23
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
24
+ )
25
+ bsz, q_len, _ = hidden_states.size()
26
+
27
+ query_states = self.q_proj(hidden_states)
28
+ key_states = self.k_proj(hidden_states)
29
+ value_states = self.v_proj(hidden_states)
30
+
31
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
32
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
33
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
34
+
35
+ kv_seq_len = key_states.shape[-2]
36
+ if past_key_value is not None:
37
+ if self.layer_idx is None:
38
+ raise ValueError(
39
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
40
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
41
+ "with a layer index."
42
+ )
43
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
44
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
45
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
46
+
47
+ if past_key_value is not None:
48
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
49
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
50
+
51
+ # repeat k/v heads if n_kv_heads < n_heads
52
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
53
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
54
+
55
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
56
+
57
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
58
+ raise ValueError(
59
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
60
+ f" {attn_weights.size()}"
61
+ )
62
+
63
+ if attention_mask is not None:
64
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
65
+ raise ValueError(
66
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
67
+ )
68
+
69
+ attn_weights = attn_weights + attention_mask
70
+
71
+ # upcast attention to fp32
72
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
73
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
74
+ attn_output = torch.matmul(attn_weights, value_states)
75
+
76
+ d = value_states.size(-1) # Custom
77
+
78
+ averaging_matrix = torch.full(value_states.shape, 1/d) # Custom
79
+ custom_attn_output = torch.matmul(attn_weights.to('cpu'), averaging_matrix) # Custom
80
+
81
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
82
+ raise ValueError(
83
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
84
+ f" {attn_output.size()}"
85
+ )
86
+
87
+ attn_output = attn_output.transpose(1, 2).contiguous()
88
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
89
+
90
+ custom_attn_output = custom_attn_output.transpose(1, 2).contiguous() # Custom
91
+ custom_attn_output = custom_attn_output.reshape(bsz, q_len, self.hidden_size) # Custom
92
+ self.attn_vec = custom_attn_output # Custom
93
+
94
+ attn_output = self.o_proj(attn_output)
95
+
96
+ if not output_attentions:
97
+ attn_weights = None
98
+
99
+ return attn_output, attn_weights, past_key_value