serbog's picture
Upload handler.py
646ce9c
raw
history blame
1.82 kB
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import Dict, List, Any
def middle_truncate(tokenized_ids, max_length, tokenizer):
    """Fit a list of token ids to ``max_length`` entries.

    Sequences that are already short enough are right-padded with
    ``tokenizer.pad_token_id`` up to ``max_length``. Longer sequences are
    trimmed from both ends: ``excess // 2`` ids come off the front and the
    remaining ``excess - excess // 2`` off the back, so only the central
    span is kept.

    NOTE(review): because both ends are trimmed, the first and last ids are
    dropped whenever truncation happens. For tokenizers that put special
    tokens there (e.g. [CLS]/[SEP]) those are lost — confirm this is intended.

    Args:
        tokenized_ids: list of token ids (list concatenation is used for padding).
        max_length: target length.
        tokenizer: object exposing ``pad_token_id``.

    Returns:
        A new list of ids.
    """
    overflow = len(tokenized_ids) - max_length
    if overflow > 0:
        head_cut = overflow // 2
        tail_cut = overflow - head_cut
        return tokenized_ids[head_cut:len(tokenized_ids) - tail_cut]
    # overflow <= 0: pad on the right with as many pad ids as are missing.
    shortfall = -overflow
    return tokenized_ids + [tokenizer.pad_token_id] * shortfall
class EndpointHandler:
    """Inference handler that middle-truncates/pads inputs and returns sigmoid scores.

    The tokenizer and sequence-classification model are loaded from ``path``.
    Each input text is tokenized without truncation and then fitted to exactly
    ``self.MAX_LENGTH`` ids with ``middle_truncate``.
    """

    def __init__(self, path=""):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.MAX_LENGTH = 512  # or any other max length you prefer

    def _encode(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenize ``texts`` and fit each to ``self.MAX_LENGTH`` ids, returning model inputs."""
        # No return_tensors here: sequences have different lengths until fitted,
        # so ``encodings["input_ids"]`` is a list of Python int lists.
        encodings = self.tokenizer(texts, padding=False, truncation=False)
        batch_ids = []
        batch_masks = []
        for ids in encodings["input_ids"]:
            fitted = middle_truncate(list(ids), self.MAX_LENGTH, self.tokenizer)
            # Build the mask from the number of real ids kept rather than by
            # comparing against pad_token_id, so a genuine token that shares the
            # pad id (e.g. EOS on some tokenizers) is not masked out.
            n_real = min(len(ids), self.MAX_LENGTH)
            batch_ids.append(fitted)
            batch_masks.append([1] * n_real + [0] * (len(fitted) - n_real))
        device = self.model.device
        return {
            "input_ids": torch.tensor(batch_ids, device=device),
            "attention_mask": torch.tensor(batch_masks, device=device),
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score ``data["inputs"]`` (a string or a list of strings).

        Note: ``data.pop`` removes ``"inputs"`` from the caller's dict; if the key
        is missing, ``data`` itself is passed to the tokenizer.

        Returns:
            ``[{"scores": probs}]`` where ``probs`` is a NumPy array of the
            element-wise sigmoid of the model logits, one row per input text.
            No threshold is applied.
        """
        inputs = data.pop("inputs", data)
        # Accept a single string (original usage) as a batch of one.
        texts = [inputs] if isinstance(inputs, str) else inputs
        model_inputs = self._encode(texts)
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        # torch.sigmoid is numerically stable for large-magnitude logits.
        probs = torch.sigmoid(outputs.logits).cpu().numpy()
        return [{"scores": probs}]