File size: 1,407 Bytes
8f18779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a22dbf1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import List, Dict, Any
import numpy as np
from transformers import BertTokenizer, BertModel
import torch
import pickle


def unpickle_obj(filepath):
    """Load and return the object pickled in the file at ``filepath``.

    The file is opened in binary mode and read with ``pickle.load``. Once the
    file has been closed again, a short confirmation line is printed.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserializing, so ``filepath`` should only ever point at a trusted file.

    Args:
        filepath: Location of the pickle file; passed straight to ``open``.

    Returns:
        Whatever object the pickle stream contains.
    """
    with open(filepath, 'rb') as handle:
        loaded = pickle.load(handle)

    print(f"unpickled {filepath}")
    return loaded


class EndpointHandler:
    """Callable handler built from a BERT encoder and an unpickled model.

    A request supplies ``queries`` and ``texts``. Both are embedded with
    BERT, every text embedding has every query embedding subtracted from it,
    and the resulting difference vectors are handed to the unpickled model's
    ``predict_proba``.
    """

    # Checkpoint name shared by the tokenizer and the encoder.
    _BERT_CHECKPOINT = 'bert-base-uncased'

    def __init__(self, path=""):
        """Load the pickled model, the tokenizer and the BERT encoder.

        Args:
            path: Passed unchanged to ``unpickle_obj``, which opens it as a
                file. NOTE(review): with the default ``""`` that open call
                would presumably fail -- confirm what the caller passes.
        """
        self.model = unpickle_obj(path)
        self.tokenizer = BertTokenizer.from_pretrained(self._BERT_CHECKPOINT)

        # Use the GPU when one is available, otherwise stay on the CPU.
        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)

        encoder = BertModel.from_pretrained(self._BERT_CHECKPOINT)
        self.bert = encoder.to(self.device)

    def get_embeddings(self, texts: List[str]):
        """Embed ``texts`` by averaging BERT's last hidden state over dim 1.

        The texts are tokenized as one padded batch, truncated to at most 512
        tokens, moved to ``self.device`` and encoded with gradients disabled.

        NOTE(review): the mean runs over every sequence position and the
        attention mask is not applied, so padded positions appear to be part
        of the average -- confirm this matches how the pickled model's
        training features were produced.

        Args:
            texts: Strings to embed.

        Returns:
            A NumPy array holding one averaged vector per input text
            (presumably shaped ``(len(texts), hidden_size)``).
        """
        encoded = self.tokenizer(
            texts,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=512,
        )
        encoded = encoded.to(self.device)

        with torch.no_grad():
            hidden_states = self.bert(**encoded).last_hidden_state

        pooled = hidden_states.mean(dim=1)
        return pooled.cpu().numpy()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the model on every (text, query) embedding difference.

        Args:
            data: Request payload with a ``'queries'`` entry and a ``'texts'``
                entry; each is forwarded to ``get_embeddings`` (so presumably
                a list of strings).

        Returns:
            A one-element list ``[{"outputs": ...}]`` holding the result of
            ``self.model.predict_proba``. Its input rows are ordered
            text-major: all queries for the first text, then all queries for
            the second text, and so on.
            NOTE(review): the value is returned exactly as ``predict_proba``
            produces it (likely a NumPy array) -- confirm the serving layer
            can serialize it.
        """
        queries = data['queries']
        texts = data['texts']

        query_vecs = self.get_embeddings(queries)
        text_vecs = self.get_embeddings(texts)

        # Broadcasting (n_texts, 1, dim) against (n_queries, dim) gives one
        # difference vector per (text, query) pair.
        pairwise = np.array(text_vecs)[:, np.newaxis] - np.array(query_vecs)
        embed_dim = len(query_vecs[0])
        features = pairwise.reshape(-1, embed_dim)

        probabilities = self.model.predict_proba(features)
        return [{"outputs": probabilities}]