from typing import Any, Dict, List

import time

import torch
from transformers import BertModel, BertTokenizerFast


class EndpointHandler:
    def __init__(self, path_to_model: str = ".", max_cache_entries: int = 10000):
        # Preload all the elements you are going to need at inference.
        self.tokenizer = BertTokenizerFast.from_pretrained(path_to_model)
        self.model = BertModel.from_pretrained(path_to_model)
        self.model = self.model.eval()
        self.cache = {}
        self.last_cache_cleanup = time.time()
        self.max_cache_entries = max_cache_entries

    def _lookup_cache(self, inputs):
        """Split inputs into cache hits (keyed by position) and misses (position, string)."""
        cached_results = {}
        uncached_inputs = []
        for index, input_string in enumerate(inputs):
            if input_string in self.cache:
                entry = self.cache[input_string]
                # Refresh the access time so LRU eviction keeps hot entries alive.
                entry["last_access"] = time.time()
                cached_results[index] = entry["pooler_output"]
            else:
                uncached_inputs.append((index, input_string))
        return uncached_inputs, cached_results

    def _store_to_cache(self, input_string, result):
        # Store the embedding together with an access timestamp for eviction.
        self.cache[input_string] = {
            "pooler_output": result,
            "last_access": time.time(),
        }

    def _cleanup_cache(self):
        # Evict least-recently-used entries, at most once every 60 seconds.
        current_time = time.time()
        if current_time - self.last_cache_cleanup > 60 and len(self.cache) > self.max_cache_entries:
            # Sort the cache entries by last access time, oldest first
            sorted_cache = sorted(self.cache.items(), key=lambda x: x[1]["last_access"])
            # Remove the oldest entries until we are back under the limit
            for i in range(len(self.cache) - self.max_cache_entries):
                del self.cache[sorted_cache[i][0]]
            self.last_cache_cleanup = current_time

    def __call__(self, data: Dict[str, Any]) -> List[List[float]]:
        """
        This method is called whenever a request is made to the endpoint.

        :param data: {"inputs": list of strings to be encoded}
        :return: a :obj:`list` with one pooler_output embedding per input
                 string, in input order; it will be serialized and returned
        """
        inputs = data["inputs"]

        # Check cache for inputs
        uncached_inputs, cached_results = self._lookup_cache(inputs)
        output_results = {}

        # Call the model only for the cache misses
        if len(uncached_inputs) != 0:
            model_inputs = [input_string for _, input_string in uncached_inputs]
            # Truncate to the model's max length so overlong inputs don't error.
            uncached_inputs_tokenized = self.tokenizer(
                model_inputs, return_tensors="pt", padding=True, truncation=True
            )

            with torch.no_grad():
                uncached_output_tensor = self.model(**uncached_inputs_tokenized)
            uncached_output_list = uncached_output_tensor.pooler_output.tolist()

            for (index, input_string), result in zip(uncached_inputs, uncached_output_list):
                self._store_to_cache(input_string, result)
                output_results[index] = result

            self._cleanup_cache()

        # Combine cached and freshly computed results, restoring input order
        output_results.update(cached_results)
        return [output_results[i] for i in range(len(inputs))]
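

# A minimal local smoke test, as a sketch: it assumes a BERT checkpoint
# (weights, config, and tokenizer files) is present in the current directory,
# and the example sentences are illustrative. In production, Inference
# Endpoints calls __call__ with the deserialized request payload instead.
if __name__ == "__main__":
    handler = EndpointHandler(path_to_model=".")
    payload = {"inputs": ["hello world", "hello world", "another sentence"]}
    embeddings = handler(payload)
    # The duplicate "hello world" is served from the cache within this call;
    # each result is one pooler_output vector of hidden_size floats.
    print(len(embeddings), len(embeddings[0]))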