from typing import Dict, List, Any
import time

import torch
from transformers import BertModel, BertTokenizerFast


class EndpointHandler:
    def __init__(self, path_to_model: str = ".", max_cache_entries: int = 10000):
        # Preload the tokenizer and model once so they are reused across requests.
        self.tokenizer = BertTokenizerFast.from_pretrained(path_to_model)
        self.model = BertModel.from_pretrained(path_to_model)
        self.model.eval()
        self.cache = {}
        self.last_cache_cleanup = time.time()
        self.max_cache_entries = max_cache_entries

    def _lookup_cache(self, inputs):
        cached_results = {}
        uncached_inputs = []

        for index, input_string in enumerate(inputs):
            if input_string in self.cache:
                # Refresh the access time so frequently used entries survive eviction.
                self.cache[input_string]["last_access"] = time.time()
                cached_results[index] = self.cache[input_string]["pooler_output"]
            else:
                uncached_inputs.append((index, input_string))

        return uncached_inputs, cached_results

    def _store_to_cache(self, input_string, result):
        # Cache the pooled embedding together with the access time used for eviction
        self.cache[input_string] = {
            "pooler_output": result,
            "last_access": time.time()
        }

    def _cleanup_cache(self):
        current_time = time.time()
        if current_time - self.last_cache_cleanup > 60 and len(self.cache) > self.max_cache_entries:
            # Sort the cache entries by last access time (oldest first)
            sorted_cache = sorted(self.cache.items(), key=lambda x: x[1]["last_access"])

            # Remove the oldest entries until the cache is back under the limit
            for i in range(len(self.cache) - self.max_cache_entries):
                del self.cache[sorted_cache[i][0]]

            self.last_cache_cleanup = current_time

    def __call__(self, data: Dict[str, Any]) -> List[List[float]]:
        """
        Called whenever a request is made to the endpoint.
        :param data: {"inputs": List[str]} - the strings to be encoded
        :return: a list of pooled embeddings, one per input string in input order,
                 which will be serialized and returned
        """
        inputs = data["inputs"]

        # Check cache for inputs
        uncached_inputs, cached_results = self._lookup_cache(inputs)

        output_results = {}

        # Run the model only on inputs that were not found in the cache
        if len(uncached_inputs) != 0:
            model_inputs = [input_string for _, input_string in uncached_inputs]
            uncached_inputs_tokenized = self.tokenizer(model_inputs, return_tensors="pt", padding=True)

            with torch.no_grad():
                uncached_output_tensor = self.model(**uncached_inputs_tokenized)

            uncached_output_list = uncached_output_tensor.pooler_output.tolist()

            # Cache the new embeddings and record each result under its original input index
            for (index, input_string), result in zip(uncached_inputs, uncached_output_list):
                self._store_to_cache(input_string, result)
                output_results[index] = result

            self._cleanup_cache()

        output_results.update(cached_results)

        return [output_results[i] for i in range(len(inputs))]
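

# --- Hypothetical local smoke test (illustration only, not part of the handler contract) ---
# A minimal sketch of how this handler could be exercised locally: construct
# EndpointHandler once, then call it with a payload of the form {"inputs": [...]}.
# The model path "." and the sample sentences below are assumptions for illustration.
if __name__ == "__main__":
    handler = EndpointHandler(path_to_model=".")

    payload = {"inputs": ["Hello world", "Caching avoids recomputing embeddings"]}

    # The first call runs the model for both sentences; the second call with the
    # same payload should be served entirely from the in-memory cache.
    first = handler(payload)
    second = handler(payload)

    print(len(first), len(first[0]))  # number of inputs, embedding dimension
    assert first == second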