from typing import Any, Dict, List

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


def middle_truncate(tokenized_ids, max_length, tokenizer):
    """Pad or truncate a token-id sequence to exactly max_length.

    Short sequences are right-padded with the tokenizer's pad token. Long
    sequences are trimmed from both ends so the central max_length tokens are
    kept, e.g. middle_truncate(list(range(10)), 6, tokenizer) keeps
    [2, 3, 4, 5, 6, 7].
    """
    if len(tokenized_ids) <= max_length:
        return tokenized_ids + [tokenizer.pad_token_id] * (
            max_length - len(tokenized_ids)
        )

    excess_length = len(tokenized_ids) - max_length
    left_remove = excess_length // 2
    right_remove = excess_length - left_remove

    return tokenized_ids[left_remove:-right_remove]


class EndpointHandler:
    def __init__(self, path=""):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.eval()
        # re-index the config's id2label values with contiguous integer keys
        self.id2label = {
            i: label for i, label in enumerate(self.model.config.id2label.values())
        }
        self.MAX_LENGTH = 512

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # standard Inference Endpoints payload: {"inputs": "..."}; a single text
        # string is assumed here rather than a batch
        inputs = data.pop("inputs", data)

        # tokenize without truncation so middle_truncate sees the full sequence;
        # max_length has no effect when truncation is disabled, so it is omitted
        encodings = self.tokenizer(inputs, padding=False, truncation=False)

        # pad or middle-truncate to the model's usable context length (MAX_LENGTH)
        truncated_input_ids = middle_truncate(
            encodings["input_ids"], self.MAX_LENGTH, self.tokenizer
        )

        # mask out padding positions; read the pad id from the tokenizer instead
        # of hard-coding 1 (the RoBERTa pad id)
        truncated_input_ids_array = np.array(truncated_input_ids)
        attention_mask = (
            truncated_input_ids_array != self.tokenizer.pad_token_id
        ).astype(int)

        # the model expects batched torch tensors, so add a batch dimension of 1
        truncated_encodings = {
            "input_ids": torch.tensor([truncated_input_ids], dtype=torch.long),
            "attention_mask": torch.as_tensor(attention_mask).unsqueeze(0),
        }

        with torch.no_grad():
            outputs = self.model(**truncated_encodings)

        # multi-label head: element-wise sigmoid, then threshold at 0.5
        probs = 1 / (1 + np.exp(-outputs.logits.cpu().numpy()))

        # pair each predicted label with its own probability (zipping the
        # predicted labels against probs[0] from index 0 would mis-align
        # labels and scores)
        return [
            {"label": self.id2label[idx], "score": float(prob)}
            for idx, prob in enumerate(probs[0])
            if prob >= 0.5
        ]
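

# A minimal local smoke test, not part of the deployed handler: it sketches how the
# endpoint would invoke EndpointHandler. MODEL_DIR and the sample text below are
# placeholders, assuming a multi-label sequence-classification checkpoint and its
# tokenizer files are available locally.
if __name__ == "__main__":
    MODEL_DIR = "./model"  # hypothetical path; replace with a real checkpoint directory
    handler = EndpointHandler(path=MODEL_DIR)
    sample = {"inputs": "Example text to classify."}
    print(handler(sample))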