from typing import List, Dict, Any

import numpy as np
from transformers import BertTokenizer, BertModel
import torch
import pickle


def unpickle_obj(filepath):
    """Load and return the pickled object stored at *filepath*.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file; only call this on trusted artifacts (e.g. the model file shipped
    alongside the endpoint), never on user-supplied data.
    """
    with open(filepath, 'rb') as f_in:
        data = pickle.load(f_in)
    print(f"unpickled {filepath}")
    return data


class EndpointHandler:
    """Score every (text, query) pair with BERT embeddings + a pickled classifier.

    Texts and queries are embedded by mean-pooling BERT's last hidden state;
    the classifier loaded from ``bert_lr.pkl`` is then applied to the
    element-wise difference ``text_embedding - query_embedding`` of each pair.
    """

    def __init__(self, path=""):
        # Object exposing ``predict_proba``; the file name suggests a logistic
        # regression trained on BERT features -- confirm against training code.
        self.model = unpickle_obj(f"{path}/bert_lr.pkl")
        self.tokenizer = BertTokenizer.from_pretrained(path, local_files_only=True)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.bert = BertModel.from_pretrained(path).to(self.device)
        # Inference only: make sure dropout etc. are disabled.
        self.bert.eval()

    def get_embeddings(self, texts: List[str]):
        """Return one mean-pooled BERT embedding per input string.

        Inputs longer than 512 tokens are truncated. The result is a NumPy
        array with one row per input. Only real (non-padding) tokens are
        averaged, so an input's embedding does not depend on which other
        inputs share the batch.

        NOTE(review): earlier versions averaged padding positions too; if the
        pickled classifier was trained on those unmasked features, consider
        retraining it on masked-mean features.
        """
        inputs = self.tokenizer(
            texts,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.bert(**inputs)
        hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_size)
        # attention_mask is 1 for real tokens and 0 for padding; broadcast it
        # over the hidden dimension so padded positions contribute nothing.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # clamp guards against division by zero for a (degenerate) all-pad row.
        counts = mask.sum(dim=1).clamp(min=1)
        return (summed / counts).cpu().numpy()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score every (text, query) pair in the request.

        ``data`` is either ``{"inputs": {"queries": [...], "texts": [...]}}``
        or the inner ``{"queries": [...], "texts": [...]}`` mapping itself.

        Returns ``[{"outputs": probs}]``. ``probs`` is the classifier's
        ``predict_proba`` output as nested lists, one row per pair, in
        text-major order: row ``i * len(queries) + j`` scores ``texts[i]``
        against ``queries[j]``. If either list is empty, ``probs`` is ``[]``.
        """
        # .get (not .pop) so the caller's request dict is left untouched.
        inputs = data.get("inputs", data)
        queries = inputs['queries']
        texts = inputs['texts']
        if len(queries) == 0 or len(texts) == 0:
            # No pairs to score; avoid running the tokenizer/model on nothing.
            return [{"outputs": []}]

        queries_vec = self.get_embeddings(queries)
        texts_vec = self.get_embeddings(texts)
        hidden_size = texts_vec.shape[-1]
        # Broadcast to (n_texts, n_queries, hidden_size), then flatten to one
        # feature row per (text, query) pair.
        diff = (texts_vec[:, np.newaxis, :] - queries_vec[np.newaxis, :, :]).reshape(-1, hidden_size)
        return [{"outputs": self.model.predict_proba(diff).tolist()}]