Start-GPT commited on
Commit
caab23d
·
verified ·
1 Parent(s): 0f15686

Create server/transformer_details.py

Browse files
Files changed (1) hide show
  1. server/transformer_details.py +269 -0
server/transformer_details.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for interfacing with the attentions from the front end.
3
+ """
4
+ import torch
5
+ from typing import List, Union
6
+ from abc import ABC, abstractmethod
7
+
8
+ from transformer_formatter import TransformerOutputFormatter
9
+ from utils.token_processing import reshape
10
+ from spacyface import (
11
+ BertAligner,
12
+ GPT2Aligner,
13
+ RobertaAligner,
14
+ DistilBertAligner,
15
+ auto_aligner
16
+ )
17
+
18
+ from transformers import (
19
+ BertForMaskedLM,
20
+ GPT2LMHeadModel,
21
+ RobertaForMaskedLM,
22
+ DistilBertForMaskedLM,
23
+ )
24
+
25
+ from utils.f import delegates, pick, memoize
26
+
27
+ def get_cls(class_name):
28
+ cls_type = {
29
+ 'bert-base-uncased': BertDetails,
30
+ 'bert-base-cased': BertDetails,
31
+ 'bert-large-uncased': BertDetails,
32
+ 'bert-large-cased': BertDetails,
33
+ 'gpt2': GPT2Details,
34
+ 'gpt2-medium': GPT2Details,
35
+ 'gpt2-large': GPT2Details,
36
+ 'roberta-base': RobertaDetails,
37
+ 'roberta-large': RobertaDetails,
38
+ 'roberta-large-mnli': RobertaDetails,
39
+ 'roberta-base-openai-detector': RobertaDetails,
40
+ 'roberta-large-openai-detector': RobertaDetails,
41
+ 'distilbert-base-uncased': DistilBertDetails,
42
+ 'distilbert-base-uncased-distilled-squad': DistilBertDetails,
43
+ 'distilgpt2': GPT2Details,
44
+ 'distilroberta-base': RobertaDetails,
45
+ }
46
+ return cls_type[class_name]
47
+
48
+ @memoize
49
+ def from_pretrained(model_name):
50
+ """Convert model name into appropriate transformer details"""
51
+ try: out = get_cls(model_name).from_pretrained(model_name)
52
+ except KeyError: raise KeyError(f"The model name of '{model_name}' either does not exist or is currently not supported")
53
+
54
+ return out
55
+
56
+
57
+ class TransformerBaseDetails(ABC):
58
+ """ All API calls will interact with this class to get the hidden states and attentions for any input sentence."""
59
+
60
+ def __init__(self, model, aligner):
61
+ self.model = model
62
+ self.aligner = aligner
63
+ self.model.eval()
64
+ self.forward_inputs = ['input_ids', 'attention_mask']
65
+
66
+ @classmethod
67
+ def from_pretrained(cls, model_name: str):
68
+ raise NotImplementedError(
69
+ """Inherit from this class and specify the Model and Aligner to use"""
70
+ )
71
+
72
+ def att_from_sentence(self, s: str, mask_attentions=False) -> TransformerOutputFormatter:
73
+ """Get formatted attention from a single sentence input"""
74
+ tokens = self.aligner.tokenize(s)
75
+ return self.att_from_tokens(tokens, s, add_special_tokens=True, mask_attentions=mask_attentions)
76
+
77
+ def att_from_tokens(
78
+ self, tokens: List[str], orig_sentence, add_special_tokens=False, mask_attentions=False
79
+ ) -> TransformerOutputFormatter:
80
+ """Get formatted attention from a list of tokens, using the original sentence for getting Spacy Metadata"""
81
+ ids = self.aligner.convert_tokens_to_ids(tokens)
82
+
83
+ # For GPT2, add the beginning of sentence token to the input. Note that this will work on all models but XLM
84
+ bost = self.aligner.bos_token_id
85
+ clst = self.aligner.cls_token_id
86
+ if (bost is not None) and (bost != clst) and add_special_tokens:
87
+ ids.insert(0, bost)
88
+
89
+ inputs = self.aligner.prepare_for_model(ids, add_special_tokens=add_special_tokens, return_tensors="pt")
90
+ parsed_input = self.format_model_input(inputs, mask_attentions=mask_attentions)
91
+ output = self.model(parsed_input['input_ids'], attention_mask=parsed_input['attention_mask'])
92
+ return self.format_model_output(inputs, orig_sentence, output)
93
+
94
+ def format_model_output(self, inputs, sentence:str, output, topk=5):
95
+ """Convert model output to the desired format.
96
+ Formatter additionally needs access to the tokens and the original sentence
97
+ """
98
+ hidden_state, attentions, contexts, logits = self.select_outputs(output)
99
+
100
+ words, probs = self.logits2words(logits, topk)
101
+
102
+ tokens = self.view_ids(inputs["input_ids"])
103
+ toks = self.aligner.meta_from_tokens(sentence, tokens, perform_check=False)
104
+
105
+ formatted_output = TransformerOutputFormatter(
106
+ sentence,
107
+ toks,
108
+ inputs["special_tokens_mask"],
109
+ attentions,
110
+ hidden_state,
111
+ contexts,
112
+ words,
113
+ probs.tolist()
114
+ )
115
+ return formatted_output
116
+
117
+ def select_outputs(self, output):
118
+ """Extract the desired hidden states as passed by a particular model through the output
119
+ In all cases, we care for:
120
+ - hidden state embeddings (tuple of n_layers + 1)
121
+ - attentions (tuple of n_layers)
122
+ - contexts (tuple of n_layers)
123
+ - Top predicted words
124
+ - Probabilities of top predicted words
125
+ """
126
+ logits, hidden_state, attentions, contexts = output
127
+
128
+ return hidden_state, attentions, contexts, logits
129
+
130
+ def format_model_input(self, inputs, mask_attentions=False):
131
+ """Parse the input for the model according to what is expected in the forward pass.
132
+ If not otherwise defined, outputs a dict containing the keys:
133
+ {'input_ids', 'attention_mask'}
134
+ """
135
+ return pick(self.forward_inputs, self.parse_inputs(inputs, mask_attentions=mask_attentions))
136
+
137
+ def logits2words(self, logits, topk=5):
138
+ probs, idxs = torch.topk(torch.softmax(logits.squeeze(0), 1), topk)
139
+ words = [self.aligner.convert_ids_to_tokens(i) for i in idxs]
140
+ return words, probs
141
+
142
+ def view_ids(self, ids: Union[List[int], torch.Tensor]) -> List[str]:
143
+ """View what the tokenizer thinks certain ids are"""
144
+ if type(ids) == torch.Tensor:
145
+ # Remove batch dimension
146
+ ids = ids.squeeze(0).tolist()
147
+
148
+ out = self.aligner.convert_ids_to_tokens(ids)
149
+ return out
150
+
151
+ def parse_inputs(self, inputs, mask_attentions=False):
152
+ """Parse the output from `tokenizer.prepare_for_model` to the desired attention mask from special tokens
153
+ Args:
154
+ - inputs: The output of `tokenizer.prepare_for_model`.
155
+ A dict with keys: {'special_token_mask', 'token_type_ids', 'input_ids'}
156
+ - mask_attentions: Flag indicating whether to mask the attentions or not
157
+ Returns:
158
+ Dict with keys: {'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'}
159
+ Usage:
160
+ ```
161
+ s = "test sentence"
162
+ # from raw sentence to tokens
163
+ tokens = tokenizer.tokenize(s)
164
+ # From tokens to ids
165
+ ids = tokenizer.convert_tokens_to_ids(tokens)
166
+ # From ids to input
167
+ inputs = tokenizer.prepare_for_model(ids, return_tensors='pt')
168
+ # Parse the input. Optionally mask the special tokens from the analysis.
169
+ parsed_input = parse_inputs(inputs)
170
+ # Run the model, pick from this output whatever inputs you want
171
+ from utils.f import pick
172
+ out = model(**pick(['input_ids'], parse_inputs(inputs)))
173
+ ```
174
+ """
175
+
176
+ out = inputs.copy()
177
+
178
+ # DEFINE SPECIAL TOKENS MASK
179
+ if "special_tokens_mask" not in inputs.keys():
180
+ special_tokens = set([self.aligner.unk_token_id, self.aligner.cls_token_id, self.aligner.sep_token_id, self.aligner.bos_token_id, self.aligner.eos_token_id, self.aligner.pad_token_id])
181
+ in_ids = inputs['input_ids'][0]
182
+ special_tok_mask = [1 if int(i) in special_tokens else 0 for i in in_ids]
183
+ inputs['special_tokens_mask'] = special_tok_mask
184
+
185
+ if mask_attentions:
186
+ out["attention_mask"] = torch.tensor(
187
+ [int(not i) for i in inputs.get("special_tokens_mask")]
188
+ ).unsqueeze(0)
189
+ else:
190
+ out["attention_mask"] = torch.tensor(
191
+ [1 for i in inputs.get("special_tokens_mask")]
192
+ ).unsqueeze(0)
193
+
194
+ return out
195
+
196
+
197
+ class BertDetails(TransformerBaseDetails):
198
+ @classmethod
199
+ def from_pretrained(cls, model_name: str):
200
+ return cls(
201
+ BertForMaskedLM.from_pretrained(
202
+ model_name,
203
+ output_attentions=True,
204
+ output_hidden_states=True,
205
+ output_additional_info=True,
206
+ ),
207
+ BertAligner.from_pretrained(model_name),
208
+ )
209
+
210
+
211
+ class GPT2Details(TransformerBaseDetails):
212
+ @classmethod
213
+ def from_pretrained(cls, model_name: str):
214
+ return cls(
215
+ GPT2LMHeadModel.from_pretrained(
216
+ model_name,
217
+ output_attentions=True,
218
+ output_hidden_states=True,
219
+ output_additional_info=True,
220
+ ),
221
+ GPT2Aligner.from_pretrained(model_name),
222
+ )
223
+
224
+ def select_outputs(self, output):
225
+ logits, _ , hidden_states, att, contexts = output
226
+ return hidden_states, att, contexts, logits
227
+
228
+ class RobertaDetails(TransformerBaseDetails):
229
+
230
+ @classmethod
231
+ def from_pretrained(cls, model_name: str):
232
+ return cls(
233
+ RobertaForMaskedLM.from_pretrained(
234
+ model_name,
235
+ output_attentions=True,
236
+ output_hidden_states=True,
237
+ output_additional_info=True,
238
+ ),
239
+ RobertaAligner.from_pretrained(model_name),
240
+ )
241
+
242
+ class DistilBertDetails(TransformerBaseDetails):
243
+ def __init__(self, model, aligner):
244
+ super().__init__(model, aligner)
245
+ self.forward_inputs = ['input_ids', 'attention_mask']
246
+
247
+ @classmethod
248
+ def from_pretrained(cls, model_name: str):
249
+ return cls(
250
+ DistilBertForMaskedLM.from_pretrained(
251
+ model_name,
252
+ output_attentions=True,
253
+ output_hidden_states=True,
254
+ output_additional_info=True,
255
+ ),
256
+ DistilBertAligner.from_pretrained(model_name),
257
+ )
258
+
259
+ def select_outputs(self, output):
260
+ """Extract the desired hidden states as passed by a particular model through the output
261
+ In all cases, we care for:
262
+ - hidden state embeddings (tuple of n_layers + 1)
263
+ - attentions (tuple of n_layers)
264
+ - contexts (tuple of n_layers)
265
+ """
266
+ logits, hidden_states, attentions, contexts = output
267
+
268
+ contexts = tuple([c.permute(0, 2, 1, 3).contiguous() for c in contexts])
269
+ return hidden_states, attentions, contexts, logits