Andrew Luo committed on
Commit
c98f9e5
1 Parent(s): 07655e0

customer handler

Browse files
Files changed (2) hide show
  1. handler.py +35 -0
  2. requirements.txt +0 -0
handler.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
3
+ import torch
4
+
5
+
6
class EndpointHandler:
    """Custom inference handler mapping a text to a sparse token-weight dict.

    The handler runs a masked-language model over the input and pools the
    vocabulary logits as ``max over tokens of log(1 + relu(logits))``, with
    padded positions zeroed by the attention mask (the pooling used by
    SPLADE-style sparse encoders). Only the non-zero vocabulary entries are
    returned.
    """

    def __init__(self, path=""):
        """Load the tokenizer and masked-LM checkpoint found at ``path``."""
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForMaskedLM.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[int, float]:
        """Encode one text into a sparse ``{token_id: weight}`` mapping.

        Args:
            data: Request payload. The text is read from ``data["inputs"]``;
                a bare string is also accepted in place of the dict.

        Returns:
            A dict whose keys are the vocabulary indices with a non-zero
            pooled activation and whose values are those activations.

        Raises:
            ValueError: If no string input can be found in ``data``.
        """
        # Read (rather than pop) the field so the caller's payload is left intact.
        text = data.get("inputs") if isinstance(data, dict) else data
        if not isinstance(text, str):
            raise ValueError('expected a string under the "inputs" key')

        tokens = self.tokenizer(text, return_tensors="pt")
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            output = self.model(**tokens)

        # Broadcast the mask over the vocabulary axis so masked positions
        # contribute zero to the max-pooling below.
        mask = tokens["attention_mask"].unsqueeze(-1)
        activations = torch.log(1 + torch.relu(output.logits)) * mask
        # Pool over the token axis, then drop only the batch axis.
        vec = torch.max(activations, dim=1).values.squeeze(0)

        # squeeze(-1) keeps a 1-D index tensor even when there are zero or
        # exactly one non-zero entries (a bare squeeze() would yield a scalar).
        cols = vec.nonzero().squeeze(-1)
        weights = vec[cols].cpu().tolist()
        return dict(zip(cols.cpu().tolist(), weights))
requirements.txt ADDED
File without changes