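"""Custom handler for a multi-label sequence-classification endpoint.

Over-length inputs are middle-truncated to the model's maximum length, logits
are converted to per-label probabilities with a sigmoid, and every label whose
probability reaches the 0.5 threshold is returned with its score.
"""
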
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import Dict, List, Any


def middle_truncate(tokenized_ids, max_length, tokenizer):
    # Short inputs are padded up to max_length with the pad token.
    if len(tokenized_ids) <= max_length:
        return tokenized_ids + [tokenizer.pad_token_id] * (
            max_length - len(tokenized_ids)
        )

    # Long inputs are truncated in the middle: keep the head and the tail so the
    # special tokens at both ends survive, and drop the excess in between.
    keep_left = max_length // 2
    keep_right = max_length - keep_left
    return tokenized_ids[:keep_left] + tokenized_ids[-keep_right:]


class EndpointHandler:
    def __init__(self, path=""):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.eval()  # inference only; disable dropout
        # Re-key id2label by position so integer indices into the logits work
        # even when the config stores string keys.
        self.id2label = {
            i: label for i, label in enumerate(self.model.config.id2label.values())
        }
        self.MAX_LENGTH = 512  # adjust to the model's maximum input length if needed

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Get the raw text from the request payload.
        inputs = data.pop("inputs", data)

        # Tokenize without truncation; middle_truncate handles over-length inputs
        # (a single string is expected, so input_ids is a flat list of token ids).
        encodings = self.tokenizer(inputs, padding=False, truncation=False)
        truncated_input_ids = middle_truncate(
            encodings["input_ids"], self.MAX_LENGTH, self.tokenizer
        )

        # Build tensors with a batch dimension; mask out padding positions.
        input_ids = torch.tensor([truncated_input_ids], dtype=torch.long)
        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        # transform logits to probabilities and apply threshold
        probs = 1 / (1 + np.exp(-outputs.logits.detach().cpu().numpy()))
        predictions = (probs >= 0.5).astype(float)

        # Map the label indices that crossed the threshold to label names, pairing
        # each label with its own probability (zipping predicted labels against the
        # full probs[0] would misalign labels and scores). The return format is up
        # to you; this is one example.
        return [
            {"label": self.id2label[idx], "score": float(probs[0][idx])}
            for idx, flag in enumerate(predictions[0])
            if flag == 1.0
        ]
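

# A minimal local smoke test; a sketch rather than part of the endpoint contract.
# In production the serving runtime instantiates EndpointHandler with the model
# repository path and calls it with the parsed request body; the path "." and the
# sample sentence below are placeholder assumptions.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    print(handler({"inputs": "An example sentence to classify."}))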