andrewluo commited on
Commit
f44ab2c
·
1 Parent(s): a630f6d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -11
handler.py CHANGED
@@ -22,14 +22,19 @@ class EndpointHandler():
22
  text = data.pop("text", data)
23
  tokens = self.tokenizer(text, return_tensors='pt', padding=True)
24
  output = self.model(**tokens)
25
- vec = torch.max(
26
- torch.log(
27
- 1 + torch.relu(output.logits)
28
- ) * tokens.attention_mask.unsqueeze(-1),
29
- dim=1)[0].squeeze()
30
- cols = vec.nonzero().squeeze().cpu().tolist()
31
- # extract the non-zero values
32
- weights = vec[cols].cpu().tolist()
33
- # use to create a dictionary of token ID to weight
34
- sparse_dict = dict(zip(map(str, cols), weights))
35
- return sparse_dict
 
 
 
 
 
 
22
  text = data.pop("text", data)
23
  tokens = self.tokenizer(text, return_tensors='pt', padding=True)
24
  output = self.model(**tokens)
25
+ results = []
26
+ for idx, x in enumerate(outputs.logits):
27
+ mask = tokens.attention_mask[idx]
28
+ mask = mask[None,:]
29
+ vec = torch.max(
30
+ torch.log(
31
+ 1 + torch.relu(x)
32
+ ) * mask.unsqueeze(-1),
33
+ dim=1)[0].squeeze()
34
+ cols = vec.nonzero().squeeze().cpu().tolist()
35
+ # extract the non-zero values
36
+ weights = vec[cols].cpu().tolist()
37
+ # use to create a dictionary of token ID to weight
38
+ sparse_dict = dict(zip(map(str, cols), weights))
39
+ results.append(sparse_dict)
40
+ return results