|
import numpy as np |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
from typing import Dict, List, Any |
|
|
|
|
|
def middle_truncate(tokenized_ids, max_length, tokenizer):
    """Fit a sequence of token ids to exactly ``max_length`` entries.

    Sequences shorter than ``max_length`` are right-padded with
    ``tokenizer.pad_token_id``. Longer sequences keep a centered window:
    the excess is split between the two ends, and when the excess is odd
    the extra token is removed from the right.

    NOTE(review): trimming both ends also discards any leading/trailing
    special tokens the tokenizer inserted (e.g. a classification token at
    position 0). Confirm this matches how the model was trained.

    Args:
        tokenized_ids: Sequence of integer token ids (list, tuple, ...).
        max_length: Target length; must be non-negative.
        tokenizer: Object exposing ``pad_token_id``; it is only read when
            padding is actually required.

    Returns:
        A new list containing exactly ``max_length`` token ids.

    Raises:
        ValueError: If ``max_length`` is negative, or if padding is needed
            but ``tokenizer.pad_token_id`` is ``None``.
    """
    if max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")

    # Copy into a list so tuples/other sequences work and the caller's
    # object is never aliased by the return value.
    ids = list(tokenized_ids)
    length = len(ids)

    if length <= max_length:
        pad_count = max_length - length
        if pad_count == 0:
            return ids
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            # Padding with None would only fail later, obscurely, when the
            # ids are turned into a tensor; fail here with a clear message.
            raise ValueError(
                "tokenizer has no pad_token_id; cannot pad a sequence of "
                f"length {length} to {max_length}"
            )
        return ids + [pad_id] * pad_count

    excess = length - max_length
    left_remove = excess // 2
    right_remove = excess - left_remove
    # Explicit end index (instead of ``-right_remove``) keeps the slice
    # correct even if right_remove were ever 0.
    return ids[left_remove:length - right_remove]
|
|
|
|
|
class EndpointHandler:
    """Custom inference handler wrapping a sequence-classification model.

    Each input text is tokenized without truncation, fitted to a fixed
    length with ``middle_truncate``, and run through the model. The handler
    returns the element-wise sigmoid of the model's logits.
    """

    def __init__(self, path: str = "", max_length: int = 512):
        """Load the tokenizer and model from ``path``.

        Args:
            path: Local directory or hub id passed to ``from_pretrained``.
            max_length: Token length each input is padded/truncated to.
                Defaults to 512, the value previously hard-coded here.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        # Attribute name kept for backward compatibility with existing callers.
        self.MAX_LENGTH = max_length

    def _encode(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenize ``texts`` into fixed-length ``input_ids``/``attention_mask`` tensors."""
        # Without ``return_tensors`` the tokenizer returns plain Python lists,
        # one per text, so no ``.tolist()`` conversion is needed.
        encodings = self.tokenizer(texts, padding=False, truncation=False)

        all_ids = []
        all_masks = []
        for ids in encodings["input_ids"]:
            all_ids.append(middle_truncate(ids, self.MAX_LENGTH, self.tokenizer))
            # Build the mask from the real token count rather than comparing
            # ids with pad_token_id: if the pad id is also used for a genuine
            # token (e.g. pad == eos), that token would be wrongly masked.
            real = min(len(ids), self.MAX_LENGTH)
            all_masks.append([1] * real + [0] * (self.MAX_LENGTH - real))

        device = self.model.device
        return {
            "input_ids": torch.tensor(all_ids, dtype=torch.long, device=device),
            "attention_mask": torch.tensor(
                all_masks, dtype=torch.long, device=device
            ),
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score the text(s) found in ``data["inputs"]``.

        Args:
            data: Request payload. ``"inputs"`` is popped from it (the dict
                is mutated) and must be a string or a list/tuple of strings.
                If the key is missing, the payload itself is taken as the
                input and is rejected by the type check below.

        Returns:
            ``[{"scores": probs}]``, where ``probs`` is a NumPy array holding
            one row of sigmoid scores per input text.

        Raises:
            TypeError: If the inputs are not a string or a list/tuple.
            ValueError: If an empty list/tuple of inputs is given.
        """
        inputs = data.pop("inputs", data)
        if isinstance(inputs, str):
            texts = [inputs]
        elif isinstance(inputs, (list, tuple)):
            texts = list(inputs)
        else:
            raise TypeError(
                "'inputs' must be a string or a list of strings, "
                f"got {type(inputs).__name__}"
            )
        if not texts:
            raise ValueError("'inputs' must contain at least one text")

        encodings = self._encode(texts)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**encodings)

        # torch.sigmoid stays well-behaved for large-magnitude logits, whereas
        # NumPy 1 / (1 + exp(-x)) overflows and emits a RuntimeWarning.
        probs = torch.sigmoid(outputs.logits).cpu().numpy()

        # NOTE(review): a NumPy array may need ``.tolist()`` if the serving
        # layer's JSON encoder does not handle ndarrays; confirm.
        return [{"scores": probs}]
|
|