Start-GPT commited on
Commit
b71e56d
1 Parent(s): 6659220

Create model_api.py

Browse files
Files changed (1) hide show
  1. server/model_api.py +153 -0
server/model_api.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union, Tuple
2
+
3
+ import torch
4
+ from transformers import AutoConfig, AutoTokenizer, AutoModelWithLMHead, AutoModel
5
+
6
+ from transformer_formatter import TransformerOutputFormatter
7
+ from utils.f import delegates, pick, memoize
8
+
9
+ @memoize
10
+ def get_details(mname):
11
+ return ModelDetails(mname)
12
+
13
+ def get_model_tok(mname):
14
+ conf = AutoConfig.from_pretrained(mname, output_attentions=True, output_past=False)
15
+ tok = AutoTokenizer.from_pretrained(mname, config=conf)
16
+ model = AutoModelWithLMHead.from_pretrained(mname, config=conf)
17
+ return model, tok
18
+
19
+ class ModelDetails:
20
+ """Wraps a transformer model and tokenizer to prepare inputs to the frontend visualization"""
21
+ def __init__(self, mname):
22
+ self.mname = mname
23
+ self.model, self.tok = get_model_tok(self.mname)
24
+ self.model.eval()
25
+ self.config = self.model.config
26
+
27
+ def from_sentence(self, sentence: str) -> TransformerOutputFormatter:
28
+ """Get attentions and word probabilities from a sentence. Special tokens are automatically added if a sentence is passed.
29
+
30
+ Args:
31
+ sentence: The input sentence to tokenize and analyze.
32
+ """
33
+ tokens = self.tok.tokenize(sentence)
34
+
35
+ return self.from_tokens(tokens, sentence, add_special_tokens=True)
36
+
37
+ def from_tokens(
38
+ self, tokens: List[str], orig_sentence:str, add_special_tokens:bool=False, mask_attentions:bool=False, topk:int=5
39
+ ) -> TransformerOutputFormatter:
40
+ """Get formatted attention and predictions from a list of tokens.
41
+ Args:
42
+ tokens: Tokens to analyze
43
+ orig_sentence: The sentence the tokens came from (needed to help organize the output)
44
+ add_special_tokens: Whether to add special tokens like CLS / <|endoftext|> to the tokens.
45
+ If False, assume the tokens already have the special tokens
46
+ mask_attentions: If True, do not pay attention to attention patterns to special tokens through the model.
47
+ topk: How many top predictions to report
48
+ """
49
+ ids = self.tok.convert_tokens_to_ids(tokens)
50
+
51
+ # For GPT2, add the beginning of sentence token to the input. Note that this will work on all models but XLM
52
+ bost = self.tok.bos_token_id
53
+ clst = self.tok.cls_token_id
54
+ sept = self.tok.sep_token_id
55
+ if (bost is not None) and (bost != clst)and add_special_tokens:
56
+ ids.insert(0, bost)
57
+
58
+ inputs = self.tok.prepare_for_model(ids, add_special_tokens=add_special_tokens, return_tensors="pt")
59
+ parsed_input = self.parse_inputs(inputs, mask_attentions=mask_attentions)
60
+ output = self.model(parsed_input['input_ids'], attention_mask=parsed_input['attention_mask'])
61
+
62
+ logits, atts = self.choose_logits_att(output)
63
+ words, probs = self.logits2words(logits, topk)
64
+ tokens = self.view_ids(inputs["input_ids"])
65
+
66
+ formatted_output = TransformerOutputFormatter(
67
+ orig_sentence,
68
+ tokens,
69
+ inputs["special_tokens_mask"],
70
+ atts,
71
+ words,
72
+ probs.tolist(),
73
+ self.config
74
+ )
75
+
76
+ return formatted_output
77
+
78
+ def choose_logits_att(self, out:Tuple) -> Tuple:
79
+ """Select from the model's output the logits and the attentions, switching on model name
80
+
81
+ Args:
82
+ out: Output from the model's forward pass
83
+ Returns:
84
+ (logits: tensor((bs, N)), attentions: Tuple[tensor(())])
85
+ """
86
+ if 't5' in self.mname:
87
+ logits, _, atts = out
88
+ else:
89
+ logits, atts = out
90
+
91
+ return logits, atts
92
+
93
+ def logits2words(self, logits, topk):
94
+ """Convert logit probabilities into words from the tokenizer's vocabulary.
95
+
96
+ """
97
+ probs, idxs = torch.topk(torch.softmax(logits.squeeze(0), 1), topk)
98
+ words = [self.tok.convert_ids_to_tokens(i) for i in idxs]
99
+ return words, probs
100
+
101
+ def view_ids(self, ids: Union[List[int], torch.Tensor]) -> List[str]:
102
+ """View what the tokenizer thinks certain ids are for a single input"""
103
+ if type(ids) == torch.Tensor:
104
+ # Remove batch dimension
105
+ ids = ids.squeeze(0).tolist()
106
+
107
+ out = self.tok.convert_ids_to_tokens(ids)
108
+ return out
109
+
110
+ def parse_inputs(self, inputs, mask_attentions=False):
111
+ """Parse the output from `tokenizer.prepare_for_model` to the desired attention mask from special tokens
112
+ Args:
113
+ - inputs: The output of `tokenizer.prepare_for_model`.
114
+ A dict with keys: {'special_token_mask', 'token_type_ids', 'input_ids'}
115
+ - mask_attentions: Flag indicating whether to mask the attentions or not
116
+ Returns:
117
+ Dict with keys: {'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'}
118
+ Usage:
119
+ ```
120
+ s = "test sentence"
121
+ # from raw sentence to tokens
122
+ tokens = tokenizer.tokenize(s)
123
+ # From tokens to ids
124
+ ids = tokenizer.convert_tokens_to_ids(tokens)
125
+ # From ids to input
126
+ inputs = tokenizer.prepare_for_model(ids, return_tensors='pt')
127
+ # Parse the input. Optionally mask the special tokens from the analysis.
128
+ parsed_input = parse_inputs(inputs)
129
+ # Run the model, pick from this output whatever inputs you want
130
+ from utils.f import pick
131
+ out = model(**pick(['input_ids'], parse_inputs(inputs)))
132
+ ```
133
+ """
134
+
135
+ out = inputs.copy()
136
+
137
+ # DEFINE SPECIAL TOKENS MASK
138
+ if "special_tokens_mask" not in inputs.keys():
139
+ special_tokens = set([self.tok.unk_token_id, self.tok.cls_token_id, self.tok.sep_token_id, self.tok.bos_token_id, self.tok.eos_token_id, self.tok.pad_token_id])
140
+ in_ids = inputs['input_ids'][0]
141
+ special_tok_mask = [1 if int(i) in special_tokens else 0 for i in in_ids]
142
+ inputs['special_tokens_mask'] = special_tok_mask
143
+
144
+ if mask_attentions:
145
+ out["attention_mask"] = torch.tensor(
146
+ [int(not i) for i in inputs.get("special_tokens_mask")]
147
+ ).unsqueeze(0)
148
+ else:
149
+ out["attention_mask"] = torch.tensor(
150
+ [1 for i in inputs.get("special_tokens_mask")]
151
+ ).unsqueeze(0)
152
+
153
+ return out