# handler.py, captured from a repository web view
# (author: Blaxzter, commit d8f6efe "Update handler.py", file size 3 kB).
from typing import Dict, List, Any
import time
import torch
from transformers import BertModel, BertTokenizerFast
class EndpointHandler:
    """Inference endpoint that encodes strings into BERT ``pooler_output`` vectors.

    Embeddings are memoized per input string in an in-memory dict. When the
    cache holds more than ``max_cache_entries`` entries it is trimmed back to
    that size, least recently used first. Trims happen at most once every
    ``cleanup_interval`` seconds, so between trims the cache may temporarily
    exceed the limit.
    """

    def __init__(self, path_to_model: str = ".", max_cache_entries: int = 10000,
                 cleanup_interval: float = 60.0):
        """Load the tokenizer and model, and start with an empty cache.

        :param path_to_model: path (or model id) passed to ``from_pretrained``
            for both the tokenizer and the model.
        :param max_cache_entries: number of entries the cache is trimmed down to.
        :param cleanup_interval: minimum number of seconds between two trims.
        """
        self.tokenizer = BertTokenizerFast.from_pretrained(path_to_model)
        # eval() switches off training-only behaviour such as dropout.
        self.model = BertModel.from_pretrained(path_to_model).eval()
        # input string -> {"pooler_output": embedding list, "last_access": monotonic time}
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.last_cache_cleanup = time.monotonic()
        self.max_cache_entries = max_cache_entries
        self.cleanup_interval = cleanup_interval

    def _lookup_cache(self, inputs):
        """Split ``inputs`` into cache hits and cache misses.

        Hits have their ``last_access`` refreshed, so that eviction in
        ``_cleanup_cache`` is least-recently-used rather than insertion-ordered.

        :param inputs: sequence of input strings.
        :return: ``(uncached_inputs, cached_results)``. ``uncached_inputs`` is a
            list of ``(index, input_string)`` pairs for misses. ``cached_results``
            maps ``index`` to the cached embedding (not the internal cache
            record) for hits.
        """
        now = time.monotonic()
        cached_results = {}
        uncached_inputs = []
        for index, input_string in enumerate(inputs):
            entry = self.cache.get(input_string)
            if entry is None:
                uncached_inputs.append((index, input_string))
            else:
                entry["last_access"] = now
                cached_results[index] = entry["pooler_output"]
        return uncached_inputs, cached_results

    def _store_to_cache(self, input_string, result):
        """Cache ``result`` (an embedding) under ``input_string`` with the current time."""
        self.cache[input_string] = {
            "pooler_output": result,
            "last_access": time.monotonic(),
        }

    def _cleanup_cache(self):
        """Evict the least recently used entries when the cache is over its limit.

        Runs at most once per ``cleanup_interval`` seconds. Calls made sooner
        than that after the previous trim do nothing.
        """
        current_time = time.monotonic()
        excess = len(self.cache) - self.max_cache_entries
        if excess <= 0 or current_time - self.last_cache_cleanup <= self.cleanup_interval:
            return
        # Oldest last_access first; ties keep insertion order (sorted is stable).
        oldest_first = sorted(self.cache, key=lambda key: self.cache[key]["last_access"])
        for key in oldest_first[:excess]:
            del self.cache[key]
        # Record the trim so the interval throttle actually takes effect.
        self.last_cache_cleanup = current_time

    def __call__(self, data: Dict[str, Any]) -> List[List[float]]:
        """
        This method is called whenever a request is made to the endpoint.

        :param data: ``{"inputs": [str, ...]}``. A single bare string is also
            accepted and treated as a one-element list.
        :return: one ``pooler_output`` vector (list of floats) per input, in
            input order. Cached and freshly computed inputs have the same shape.
        """
        inputs = data['inputs']
        # A bare string would otherwise be iterated character by character.
        if isinstance(inputs, str):
            inputs = [inputs]

        uncached_inputs, output_results = self._lookup_cache(inputs)

        # Run the model only for inputs not already cached.
        if uncached_inputs:
            model_inputs = [input_string for _, input_string in uncached_inputs]
            # truncation=True cuts over-long inputs to the tokenizer's max length
            # instead of letting them fail inside the model.
            tokenized = self.tokenizer(
                model_inputs, return_tensors="pt", padding=True, truncation=True
            )
            with torch.no_grad():
                model_output = self.model(**tokenized)
            pooled = model_output.pooler_output.tolist()
            for (index, input_string), result in zip(uncached_inputs, pooled):
                self._store_to_cache(input_string, result)
                output_results[index] = result
            self._cleanup_cache()

        return [output_results[i] for i in range(len(inputs))]