# Custom handler (handler.py) for a Hugging Face Inference Endpoint

from typing import Any, Dict, Optional, Tuple
import math
import warnings

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from peft import PeftModel
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig

# Override the standard Mistral attention so the attention vector can be extracted
from transformers.cache_utils import Cache
from transformers.models.mistral.modeling_mistral import (
    MistralAttention,
    apply_rotary_pos_emb,
    repeat_kv,
)


class MistralAttention(MistralAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
            )
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        d = value_states.size(-1)                                                                   # Custom
        # Custom: multiply the attention weights by a uniform (1/d) matrix instead of the value
        # vectors; done on CPU in float32 so the dtypes of the two operands match for the matmul.
        averaging_matrix = torch.full(value_states.shape, 1.0 / d, dtype=torch.float32)             # Custom
        custom_attn_output = torch.matmul(attn_weights.to("cpu", torch.float32), averaging_matrix)  # Custom

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        custom_attn_output = custom_attn_output.transpose(1, 2).contiguous()                        # Custom
        custom_attn_output = custom_attn_output.reshape(bsz, q_len, self.hidden_size)               # Custom
        self.attn_vec = custom_attn_output                                                          # Custom

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
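

# Note: the subclass above only rebinds the name `MistralAttention` inside this module; a model
# loaded with `AutoModel.from_pretrained` still builds its layers from the attention class
# registered in `transformers`. A minimal sketch of one way to route the eager attention path
# to the subclass, assuming a transformers version (~4.36/4.37) that exposes
# `MISTRAL_ATTENTION_CLASSES` (otherwise the forward method is patched in place):
from transformers.models.mistral import modeling_mistral as _mistral_modeling

if hasattr(_mistral_modeling, "MISTRAL_ATTENTION_CLASSES"):
    # attn_implementation="eager" (used in EndpointHandler below) will now pick up the subclass.
    _mistral_modeling.MISTRAL_ATTENTION_CLASSES["eager"] = MistralAttention
else:
    # Fallback assumption: patch the original class's forward directly.
    _mistral_modeling.MistralAttention.forward = MistralAttention.forward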



class EndpointHandler:
    def __init__(self, model_dir=''):
        self.instruction = 'Given a web search query, retrieve relevant passages that answer the query:\n'
        self.max_length = 4096
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"


        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        self.tokenizer.pad_token = '[PAD]'
        self.tokenizer.padding_side = 'left'

        bnb_config = BitsAndBytesConfig(load_in_8bit=True)  # BitsAndBytesConfig has no 8-bit compute-dtype option

        self.model = AutoModel.from_pretrained(
            model_dir,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="eager",
        )
        self.model = PeftModel.from_pretrained(self.model, '/lora')
        self.model.eval()


    def last_token_pool(self, last_hidden_states: Tensor,
                        attention_mask: Tensor) -> Tensor:
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


    def tokenize(self, text, type):
        if type == 'query':
            text = self.instruction + text
        return self.tokenizer(text + self.tokenizer.eos_token, max_length=self.max_length, truncation=True, return_tensors='pt').to(self.device)


    def extract_attn_vec(self, model):
        # The PEFT/LoRA wrappers forward attribute access to the wrapped Mistral model,
        # so this reaches the attention vector stored by the last decoder layer.
        return model.layers[-1].self_attn.attn_vec


    def embed(self, text, type):
        tokens = self.tokenize(text, type)
        with torch.no_grad():
            output = self.model(tokens['input_ids'], tokens['attention_mask']).last_hidden_state.detach()
            embedding = self.last_token_pool(output, tokens['attention_mask'])
            embedding = F.normalize(embedding, p=2, dim=1)

            attn_vec = self.extract_attn_vec(self.model)
            attn_vec = self.last_token_pool(attn_vec, tokens['attention_mask'])
            attn_vec = F.normalize(attn_vec, p=2, dim=1)
            del output, tokens
            torch.cuda.empty_cache()
            return embedding, attn_vec


    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            id   (:obj:`str`): identifier echoed back in the response
            text (:obj:`str`): the query or document text to embed
            type (:obj:`str`): either 'query' or 'document'
        Return:
            A :obj:`dict` that will be serialized and returned as JSON
        """

        id = data.pop("id", None)
        text = data.pop("text", None)
        type = data.pop("type", None)

        embeddings, attn_vec = self.embed(text, type)
        embeddings = embeddings[0].tolist()
        attn_vec = attn_vec[0].tolist()

        if type == 'query':
            return {"id": id, "embedding": embeddings, "attention_vec": attn_vec}
        
        elif type == 'document':
            return {"id": id, "embedding": embeddings}
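

# --- Hedged usage sketch (illustration only; payload values and local paths are assumptions) ---
# Inference Endpoints construct `EndpointHandler(model_dir)` once and then call it per request.
# Running this locally also requires the hard-coded '/lora' adapter path used in __init__.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir=".")  # hypothetical local copy of the model repository

    query = {"id": "q-1", "text": "how does attention work in transformers?", "type": "query"}
    query_out = handler(query)

    document = {"id": "d-1", "text": "Attention lets a model weight tokens by relevance.", "type": "document"}
    doc_out = handler(document)

    # Embeddings are L2-normalized, so the dot product equals the cosine similarity.
    score = sum(q * d for q, d in zip(query_out["embedding"], doc_out["embedding"]))
    print(query_out["id"], doc_out["id"], "cosine similarity:", score)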