serbog committed on
Commit
dafd68e
·
1 Parent(s): 1ce6479

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +62 -0
handler.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ from typing import Dict, List, Any
4
+
5
+
6
def middle_truncate(tokenized_ids, max_length, tokenizer):
    """Pad or centre-crop a list of token ids to ``max_length`` entries.

    Args:
        tokenized_ids: List of token ids for a single sequence.
        max_length: Target number of ids in the result.
        tokenizer: Object exposing ``pad_token_id``; that id is used as the
            right-hand filler when the input is shorter than ``max_length``.

    Returns:
        A new list. Inputs that already fit are right-padded with the pad id;
        longer inputs keep only their central window, dropping ids from both
        ends (the extra one comes off the end when the overflow is odd).
    """
    overflow = len(tokenized_ids) - max_length

    # Fits already: top up on the right with pad ids (none when exact).
    if overflow <= 0:
        filler = [tokenizer.pad_token_id] * (-overflow)
        return tokenized_ids + filler

    # Too long: split the overflow between the two ends. ``drop_back`` is at
    # least 1 here, so the negative stop index is always well defined.
    drop_front = overflow // 2
    drop_back = overflow - drop_front
    return tokenized_ids[drop_front:-drop_back]
17
+
18
+
19
class EndpointHandler:
    """Callable inference handler for a multi-label sequence classifier.

    The handler tokenizes one input text, pads or centre-crops it to
    ``MAX_LENGTH`` token ids, runs the model, applies an element-wise sigmoid
    to the logits and reports every label whose probability reaches 0.5.
    """

    def __init__(self, path=""):
        """Load the tokenizer and the classification model from ``path``."""
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        # Re-key the label names by position (0..n-1) in the config's
        # iteration order so they can be indexed by logit column.
        self.id2label = dict(enumerate(self.model.config.id2label.values()))
        self.MAX_LENGTH = 512  # or any other max length you prefer

    @staticmethod
    def _labels_above_threshold(probs_row, id2label, threshold=0.5):
        """Return label/score dicts for probabilities that reach ``threshold``.

        Each selected label is paired with its *own* probability (looked up by
        the same index), and the score is converted to a built-in ``float`` so
        the result can be serialized as JSON.
        """
        return [
            {"label": id2label[idx], "score": float(prob)}
            for idx, prob in enumerate(probs_row)
            if prob >= threshold
        ]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the text found under ``data["inputs"]``.

        Args:
            data: Request payload. The text is read from the ``"inputs"`` key;
                when that key is absent the payload itself is passed to the
                tokenizer. Assumes a single text, not a batch -- TODO confirm
                against the caller.

        Returns:
            One ``{"label": str, "score": float}`` dict per label whose
            sigmoid probability is at least 0.5 (possibly an empty list).
        """
        # Function-scope import: the model's ``.detach().cpu()`` outputs show
        # it is a PyTorch model, but this file does not import torch itself.
        import torch

        # get inputs
        inputs = data.pop("inputs", data)

        # Truncation is done by ``middle_truncate`` below, so the tokenizer is
        # asked for the full, unpadded id list.
        encodings = self.tokenizer(inputs, padding=False, truncation=False)

        # NOTE(review): the original hard-coded 514 here and never read
        # ``self.MAX_LENGTH``. Pad id 1 suggests a RoBERTa-style model, whose
        # usable sequence length is presumably 512 -- confirm for this model.
        input_ids = middle_truncate(
            encodings["input_ids"], self.MAX_LENGTH, self.tokenizer
        )

        # Mask out padding using the tokenizer's own pad id rather than a
        # hard-coded literal.
        pad_id = self.tokenizer.pad_token_id
        attention_mask = [int(token_id != pad_id) for token_id in input_ids]

        # The model needs batched tensors: wrap the single sequence in a
        # batch of size 1 and place it on the model's device.
        device = self.model.device
        model_inputs = {
            "input_ids": torch.tensor(
                [input_ids], dtype=torch.long, device=device
            ),
            "attention_mask": torch.tensor(
                [attention_mask], dtype=torch.long, device=device
            ),
        }

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        # transform logits to probabilities and apply threshold
        logits = outputs.logits.detach().cpu().numpy()
        probs = 1 / (1 + np.exp(-logits))

        # Only one sequence was submitted, so only row 0 is reported.
        return self._labels_above_threshold(probs[0], self.id2label, 0.5)