File size: 1,820 Bytes
dafd68e
646ce9c
dafd68e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
646ce9c
55990e0
dafd68e
55990e0
dafd68e
55990e0
 
 
 
 
dafd68e
55990e0
 
dafd68e
 
646ce9c
 
dafd68e
 
 
 
 
 
646ce9c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import Dict, List, Any


def middle_truncate(tokenized_ids, max_length, tokenizer):
    """Fit one sequence of token ids to exactly ``max_length`` entries.

    Sequences no longer than ``max_length`` are right-padded with
    ``tokenizer.pad_token_id``. Longer sequences keep their central
    ``max_length`` tokens: the excess is trimmed from both ends, and when
    the excess is odd the extra token is removed from the right end.

    NOTE(review): because both ends are trimmed, any special tokens the
    tokenizer placed at the very start/end of the sequence (e.g. a leading
    classification token) are removed too — confirm this is intended for
    the model being served.

    Args:
        tokenized_ids: Token ids of a single sequence (list or tuple).
        max_length: Target length; must be non-negative.
        tokenizer: Object exposing ``pad_token_id``; it is only read when
            padding is required.

    Returns:
        A new list containing exactly ``max_length`` token ids.

    Raises:
        ValueError: If ``max_length`` is negative, or if padding is required
            but ``tokenizer.pad_token_id`` is ``None``.
    """
    if max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")

    length = len(tokenized_ids)

    if length <= max_length:
        pad_count = max_length - length
        if pad_count == 0:
            return list(tokenized_ids)
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            # Padding with None would only fail later (e.g. in torch.tensor)
            # with an unrelated-looking error; fail here with a clear one.
            raise ValueError(
                "tokenizer has no pad_token_id; cannot pad sequence of length "
                f"{length} to {max_length}"
            )
        return list(tokenized_ids) + [pad_id] * pad_count

    excess = length - max_length
    left_remove = excess // 2
    right_remove = excess - left_remove
    # Use an explicit end index rather than -right_remove so the slice stays
    # obviously correct for every excess size.
    return list(tokenized_ids[left_remove:length - right_remove])


class EndpointHandler:
    """Custom inference handler that scores text with a sequence classifier.

    Every input sequence is fitted to exactly ``MAX_LENGTH`` tokens with
    ``middle_truncate`` before the forward pass, and the model's logits are
    mapped to independent per-label probabilities with a sigmoid.
    """

    def __init__(self, path=""):
        """Load the tokenizer and the sequence-classification model from ``path``."""
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.MAX_LENGTH = 512  # or any other max length you prefer

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score the texts in the request payload.

        Args:
            data: Request payload. Its ``"inputs"`` entry is popped and used
                as the text; if the key is absent, ``data`` itself is passed
                to the tokenizer. A single string or a list of strings is
                accepted.

        Returns:
            ``[{"scores": probs}]`` where ``probs`` is a NumPy array with one
            row per input text and one sigmoid probability per model output.

        Raises:
            ValueError: If ``inputs`` is an empty list.
        """
        inputs = data.pop("inputs", data)

        # Padding/truncation are done by middle_truncate below, so the
        # tokenizer must return the full, unpadded ids (as Python lists).
        encodings = self.tokenizer(inputs, padding=False, truncation=False)
        input_ids = encodings["input_ids"]

        # A single string yields one flat list of ids; a list of strings
        # yields one list of ids per string. Normalize to a list of sequences.
        sequences = [input_ids] if isinstance(inputs, str) else list(input_ids)
        if not sequences:
            raise ValueError("'inputs' must contain at least one text")

        ids_rows = []
        mask_rows = []
        for ids in sequences:
            ids_rows.append(middle_truncate(list(ids), self.MAX_LENGTH, self.tokenizer))
            # middle_truncate right-pads, so the real tokens occupy the first
            # min(len(ids), MAX_LENGTH) positions. Building the mask from that
            # length (instead of comparing against pad_token_id) keeps real
            # tokens that share the pad id, e.g. when pad is set to eos.
            n_real = min(len(ids), self.MAX_LENGTH)
            mask_rows.append([1] * n_real + [0] * (self.MAX_LENGTH - n_real))

        # Place inputs on whatever device the model lives on.
        device = self.model.device
        model_inputs = {
            "input_ids": torch.tensor(ids_rows, device=device),
            "attention_mask": torch.tensor(mask_rows, device=device),
        }

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        # Sigmoid maps each logit to an independent probability (no threshold
        # is applied here); torch.sigmoid is numerically stable for large logits.
        probs = torch.sigmoid(outputs.logits).cpu().numpy()

        return [{"scores": probs}]