from typing import Dict, List, Any
import logging
import pickle

import numpy as np
import torch
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer

from eurovoc import EurovocTagger

logger = logging.getLogger(__name__)

BERT_MODEL_NAME = "EuropeanParliament/EUBERT"
# Token length every chunk is padded/truncated to in _get_prediction. It is also
# reused by get_prediction as the *character* length of each text chunk.
MAX_LEN = 512
# Only the first TEXT_MAX_LEN characters of the input (at most 50 chunks) are scored.
TEXT_MAX_LEN = MAX_LEN * 50

tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)


class EndpointHandler:
    """Inference endpoint that scores text against the labels of a fitted
    MultiLabelBinarizer using an EurovocTagger model."""

    # Class-level placeholder; every instance replaces it in __init__ with the
    # binarizer unpickled from the model directory.
    mlb = MultiLabelBinarizer()

    def __init__(self, path=""):
        """Load the label binarizer and the model from ``path``.

        Args:
            path: Directory containing ``mlb.pickle`` and the model files
                understood by ``EurovocTagger.from_pretrained``.
        """
        # NOTE(security): unpickling can execute arbitrary code; ``path`` must be
        # a trusted model directory, never user-supplied data.
        with open(f"{path}/mlb.pickle", "rb") as fh:
            self.mlb = pickle.load(fh)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = EurovocTagger.from_pretrained(
            path,
            bert_model_name=BERT_MODEL_NAME,
            n_classes=len(self.mlb.classes_),
            map_location=self.device,
        )
        self.model.eval()
        self.model.freeze()
        # Tokenizer tensors are created on CPU; they must be moved to wherever the
        # weights actually ended up. Read that off the model itself, falling back
        # to self.device for a model without parameters.
        first_param = next(iter(self.model.parameters()), None)
        self._input_device = first_param.device if first_param is not None else self.device

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score one request payload and return the best-matching labels.

        Args:
            data: Request payload. The following keys are popped (so ``data`` is
                mutated):
                ``inputs``    text to classify; if absent, ``data`` itself is used;
                ``topk``      maximum number of labels returned (default 5);
                ``threshold`` a label is kept only if its score is strictly
                              greater than this (default 0.16);
                ``debug``     when truthy, also return raw scores and the input.

        Returns:
            ``{"results": [{"label": ..., "score": float}, ...]}`` sorted by
            descending score. In debug mode, the dict additionally contains
            ``"values"`` (averaged raw scores as nested lists) and ``"input"``.
        """
        text = data.pop("inputs", data)
        topk = data.pop("topk", 5)
        threshold = data.pop("threshold", 0.16)
        debug = data.pop("debug", False)

        prediction = self.get_prediction(text)
        results: List[Dict[str, Any]] = [
            {"label": label, "score": float(score)}
            for label, score in zip(self.mlb.classes_, prediction[0].tolist())
        ]
        results = sorted(results, key=lambda x: x["score"], reverse=True)
        results = [r for r in results if r["score"] > threshold]
        results = results[:topk]

        if debug:
            # .tolist(): a raw ndarray is not JSON-serializable by the endpoint.
            return {"results": results, "values": prediction.tolist(), "input": text}
        return {"results": results}

    def get_prediction(self, text):
        """Return the element-wise mean of the model scores over chunks of ``text``.

        ``text`` is cut into consecutive slices of MAX_LEN *characters* (not
        tokens). Only the first TEXT_MAX_LEN characters are considered. Each slice
        is scored by ``_get_prediction``, and the results are averaged over the
        chunk axis.

        NOTE(review): __call__ indexes ``prediction[0]``, so each per-chunk array
        is presumably shaped (1, n_classes); confirm against EurovocTagger.
        """
        end = min(len(text), TEXT_MAX_LEN)
        chunks = [text[i:i + MAX_LEN] for i in range(0, end, MAX_LEN)]
        if not chunks:
            # Empty text gives no chunks, and np.array([]).mean(axis=0) is a NaN
            # scalar that makes prediction[0] raise. Score the empty text as a
            # single chunk instead so the output keeps its usual shape.
            chunks = [text]
        predictions = [self._get_prediction(chunk) for chunk in chunks]
        return np.array(predictions).mean(axis=0)

    def _get_prediction(self, text):
        """Run the model on one chunk and return its scores as a NumPy array.

        The chunk is tokenized with special tokens, then padded and truncated to
        exactly MAX_LEN tokens. The second element of the model's output tuple is
        returned after moving it to CPU.
        """
        item = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        # Inputs must be on the same device as the weights (GPU when available).
        input_ids = item["input_ids"].to(self._input_device)
        attention_mask = item["attention_mask"].to(self._input_device)
        with torch.no_grad():
            _, prediction = self.model(input_ids, attention_mask)
        prediction = prediction.cpu().detach().numpy()
        logger.debug("chunk=%r prediction=%s", text, prediction)
        return prediction