lr-bert-base-uncased / handler.py
Uan Sholanbayev
fix output format
a22dbf1
raw
history blame
No virus
1.41 kB
from typing import List, Dict, Any
import numpy as np
from transformers import BertTokenizer, BertModel
import torch
import pickle
def unpickle_obj(filepath):
    """Deserialize and return the object stored in a pickle file.

    A short confirmation line is printed to stdout after the object has
    been loaded.

    SECURITY: ``pickle.load`` can execute arbitrary code contained in the
    file, so this must only be used on files from a trusted source.

    Args:
        filepath: Location of the pickle file; opened in binary read mode.

    Returns:
        The Python object reconstructed from the file.
    """
    with open(filepath, 'rb') as stream:
        loaded = pickle.load(stream)
    print(f"unpickled {filepath}")
    return loaded
class EndpointHandler:
    """Score every (text, query) pair with a pickled classifier.

    Texts and queries are embedded with ``bert-base-uncased`` by averaging
    the last hidden state over the token axis. The classifier loaded in
    ``__init__`` is then applied to the element-wise difference
    ``text_embedding - query_embedding`` of each pair.
    """

    def __init__(self, path=""):
        """Load the classifier, the tokenizer and the BERT encoder.

        Args:
            path: Passed unchanged to ``unpickle_obj``, which opens it as a
                pickle file.
                NOTE(review): hosting runtimes commonly hand a handler the
                model *directory*; confirm the deployment passes the path of
                the pickle file itself.
        """
        self.model = unpickle_obj(path)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.bert = BertModel.from_pretrained('bert-base-uncased').to(self.device)

    def get_embeddings(self, texts: List[str]):
        """Embed ``texts`` as the mean of BERT's last hidden state.

        Inputs are truncated to 512 tokens and padded to a common length
        within the batch; the forward pass runs without gradient tracking.

        NOTE(review): the mean is taken over every position along dim 1,
        padding included (no attention-mask weighting), so an embedding can
        vary with the longest sequence in the batch. Left unchanged because
        the pickled classifier was presumably fitted on embeddings produced
        this way - confirm before altering the pooling.

        Args:
            texts: Strings to embed.

        Returns:
            A NumPy array on the CPU, presumably of shape
            ``(len(texts), hidden_size)``.
        """
        inputs = self.tokenizer(texts, return_tensors='pt', truncation=True,
                                padding=True, max_length=512).to(self.device)
        with torch.no_grad():
            outputs = self.bert(**inputs)
        return outputs.last_hidden_state.mean(dim=1).cpu().numpy()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return classifier probabilities for every (text, query) pair.

        Args:
            data: Request payload with the keys ``'queries'`` and
                ``'texts'``, each holding the strings to embed.

        Returns:
            A one-element list ``[{"outputs": rows}]`` where ``rows`` is the
            result of ``predict_proba`` converted to nested Python lists so
            the response can be JSON-encoded. Rows are ordered text-major:
            row ``i * len(queries) + j`` belongs to ``texts[i]`` paired with
            ``queries[j]``.

        Raises:
            KeyError: If ``'queries'`` or ``'texts'`` is missing from
                ``data``.
        """
        queries = data['queries']
        texts = data['texts']
        queries_vec = np.asarray(self.get_embeddings(queries))
        texts_vec = np.asarray(self.get_embeddings(texts))
        # (T, 1, D) - (Q, D) broadcasts to (T, Q, D); flattening the first
        # two axes yields one feature row per (text, query) pair.
        diff = (texts_vec[:, np.newaxis] - queries_vec).reshape(
            -1, queries_vec.shape[-1])
        probabilities = self.model.predict_proba(diff)
        # NumPy arrays are not serializable by standard JSON encoders, so
        # hand back plain nested lists instead of the raw array.
        return [{
            "outputs": np.asarray(probabilities).tolist()
        }]