Blaxzter committed
Commit 7909a1d
1 Parent(s): a38b233

Update handler.py


Add cache to handler

Files changed (1)
  handler.py  +58 -5
handler.py CHANGED
@@ -1,16 +1,49 @@
 from typing import Dict, List, Any
+import time
 
 import torch
 from transformers import BertModel, BertTokenizerFast
 
 
 class EndpointHandler():
-    def __init__(self, path_to_model: str = '.'):
+    def __init__(self, path_to_model: str = ".", max_cache_entries: int = 10000):
         # Preload all the elements you are going to need at inference.
         # pseudo:
         self.tokenizer = BertTokenizerFast.from_pretrained(path_to_model)
         self.model = BertModel.from_pretrained(path_to_model)
         self.model = self.model.eval()
+        self.cache = {}
+        self.last_cache_cleanup = time.time()
+        self.max_cache_entries = max_cache_entries
+
+    def _lookup_cache(self, inputs):
+        # Split the batch into cached results (keyed by batch index) and
+        # inputs that still need a forward pass.
+        cached_results = {}
+        uncached_inputs = []
+
+        for index, input_string in enumerate(inputs):
+            if input_string in self.cache:
+                cached_results[index] = self.cache[input_string]["pooler_output"]
+            else:
+                uncached_inputs.append((index, input_string))
+
+        return uncached_inputs, cached_results
+
+    def _store_to_cache(self, input_string, result):
+        # Store the embedding together with its last access time
+        self.cache[input_string] = {
+            "pooler_output": result,
+            "last_access": time.time()
+        }
+
+    def _cleanup_cache(self):
+        current_time = time.time()
+        if current_time - self.last_cache_cleanup > 5 and len(self.cache) > self.max_cache_entries:
+            # Sort the cache by last access time
+            sorted_cache = sorted(self.cache.items(), key = lambda x: x[1]["last_access"])
+
+            # Remove the oldest entries
+            for i in range(len(self.cache) - self.max_cache_entries):
+                del self.cache[sorted_cache[i][0]]
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
@@ -18,10 +51,30 @@ class EndpointHandler():
         :param data: { inputs [str]: list of strings to be encoded }
         :return: A :obj:`list` | `dict`: will be serialized and returned
         """
+        inputs = data['inputs']
+
+        # Check cache for inputs
+        uncached_inputs, cached_results = self._lookup_cache(inputs)
+
+        output_results = {}
+
+        # Call model for uncached input
+        if len(uncached_inputs) != 0:
+            model_inputs = [input_string for _, input_string in uncached_inputs]
+            uncached_inputs_tokenized = self.tokenizer(model_inputs, return_tensors = "pt", padding = True)
+
+            with torch.no_grad():
+                uncached_output_tensor = self.model(**uncached_inputs_tokenized)
+
+            uncached_output_list = uncached_output_tensor.pooler_output.tolist()
+
+            # combine cached and uncached results
+            for (index, input_string), result in zip(uncached_inputs, uncached_output_list):
+                self._store_to_cache(input_string, result)
+                output_results[index] = result
 
-        inputs = self.tokenizer(data['inputs'], return_tensors = "pt", padding = True)
+        self._cleanup_cache()
 
-        with torch.no_grad():
-            outputs = self.model(**inputs)
+        output_results.update(cached_results)
 
-        return outputs.pooler_output.tolist()
+        return [output_results[i] for i in range(len(inputs))]
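
For context, a minimal sketch of how the updated handler might be exercised locally. It is not part of the commit: the "./model" checkpoint path and the example strings are placeholders, and it assumes handler.py sits next to a compatible BERT checkpoint directory. The second call with the same batch should be answered entirely from self.cache, with no tokenization or forward pass.

# Usage sketch (illustrative, not part of the commit).
from handler import EndpointHandler

handler = EndpointHandler(path_to_model="./model", max_cache_entries=10000)
batch = {"inputs": ["first sentence", "second sentence"]}

# First call: nothing is cached yet, so both strings go through the model.
embeddings = handler(batch)
print(len(embeddings), len(embeddings[0]))  # 2 pooled embeddings, hidden-size floats each

# Second call: both strings are now served from the cache.
cached = handler(batch)
assert cached == embeddings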