Joshua Lochner commited on
Commit
004109e
·
1 Parent(s): f11d2c2

Fix shared prediction cache

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -42,26 +42,29 @@ st.set_page_config(
42
  def persistdata():
43
  return {}
44
 
 
45
 
46
  MODELS = {
47
  'Small (77M)': {
48
  'pretrained': 'google/t5-v1_1-small',
49
  'repo_id': 'Xenova/sponsorblock-small',
50
- 'cache': persistdata()
51
  },
52
  'Base v1 (220M)': {
53
  'pretrained': 't5-base',
54
  'repo_id': 'EColi/sponsorblock-base-v1',
55
- 'cache': persistdata()
56
  },
57
 
58
  'Base v1.1 (250M)': {
59
  'pretrained': 'google/t5-v1_1-base',
60
  'repo_id': 'Xenova/sponsorblock-base',
61
- 'cache': persistdata()
62
  }
63
  }
64
 
 
 
 
 
 
65
  CATGEGORY_OPTIONS = {
66
  'SPONSOR': 'Sponsor',
67
  'SELFPROMO': 'Self/unpaid promo',
@@ -98,13 +101,13 @@ def load_predict(model_id):
98
  )
99
 
100
  def predict_function(video_id):
101
- if video_id not in model_info['cache']:
102
- model_info['cache'][video_id] = pred(
103
  video_id, model, tokenizer,
104
  segmentation_args=segmentation_args,
105
  classifier_args=classifier_args
106
  )
107
- return model_info['cache'][video_id]
108
 
109
  return predict_function
110
 
 
42
  def persistdata():
43
  return {}
44
 
45
+ prediction_cache = persistdata()
46
 
47
  MODELS = {
48
  'Small (77M)': {
49
  'pretrained': 'google/t5-v1_1-small',
50
  'repo_id': 'Xenova/sponsorblock-small',
 
51
  },
52
  'Base v1 (220M)': {
53
  'pretrained': 't5-base',
54
  'repo_id': 'EColi/sponsorblock-base-v1',
 
55
  },
56
 
57
  'Base v1.1 (250M)': {
58
  'pretrained': 'google/t5-v1_1-base',
59
  'repo_id': 'Xenova/sponsorblock-base',
 
60
  }
61
  }
62
 
63
+ # Create per-model cache
64
+ for m in MODELS:
65
+ if m not in prediction_cache:
66
+ prediction_cache[m] = {}
67
+
68
  CATGEGORY_OPTIONS = {
69
  'SPONSOR': 'Sponsor',
70
  'SELFPROMO': 'Self/unpaid promo',
 
101
  )
102
 
103
  def predict_function(video_id):
104
+ if video_id not in prediction_cache[model_id]:
105
+ prediction_cache[model_id][video_id] = pred(
106
  video_id, model, tokenizer,
107
  segmentation_args=segmentation_args,
108
  classifier_args=classifier_args
109
  )
110
+ return prediction_cache[model_id][video_id]
111
 
112
  return predict_function
113