Joshua Lochner commited on
Commit
23a1215
·
1 Parent(s): a294fb2

Use partial functions to allow pickling of prediction functions

Browse files
Files changed (1) hide show
  1. app.py +13 -10
app.py CHANGED
@@ -1,4 +1,5 @@
1
 
 
2
  from math import ceil, floor
3
  import streamlit.components.v1 as components
4
  from transformers import (
@@ -86,6 +87,16 @@ def download_classifier(classifier_args):
86
  return True
87
 
88
 
 
 
 
 
 
 
 
 
 
 
89
  @st.cache(persist=True, allow_output_mutation=True)
90
  def load_predict(model_id):
91
  model_info = MODELS[model_id]
@@ -102,16 +113,7 @@ def load_predict(model_id):
102
 
103
  download_classifier(classifier_args)
104
 
105
- def predict_function(video_id):
106
- if video_id not in prediction_cache[model_id]:
107
- prediction_cache[model_id][video_id] = pred(
108
- video_id, model, tokenizer,
109
- segmentation_args=segmentation_args,
110
- classifier_args=classifier_args
111
- )
112
- return prediction_cache[model_id][video_id]
113
-
114
- return predict_function
115
 
116
 
117
  def main():
@@ -192,5 +194,6 @@ def main():
192
  wiki_link = '[Review generated segments before submitting!](https://wiki.sponsor.ajay.app/w/Automating_Submissions)'
193
  st.markdown(wiki_link, unsafe_allow_html=True)
194
 
 
195
  if __name__ == '__main__':
196
  main()
 
1
 
2
+ from functools import partial
3
  from math import ceil, floor
4
  import streamlit.components.v1 as components
5
  from transformers import (
 
87
  return True
88
 
89
 
90
+ def predict_function(model_id, model, tokenizer, segmentation_args, classifier_args, video_id):
91
+ if video_id not in prediction_cache[model_id]:
92
+ prediction_cache[model_id][video_id] = pred(
93
+ video_id, model, tokenizer,
94
+ segmentation_args=segmentation_args,
95
+ classifier_args=classifier_args
96
+ )
97
+ return prediction_cache[model_id][video_id]
98
+
99
+
100
  @st.cache(persist=True, allow_output_mutation=True)
101
  def load_predict(model_id):
102
  model_info = MODELS[model_id]
 
113
 
114
  download_classifier(classifier_args)
115
 
116
+ return partial(predict_function, model_id, model, tokenizer, segmentation_args, classifier_args)
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  def main():
 
194
  wiki_link = '[Review generated segments before submitting!](https://wiki.sponsor.ajay.app/w/Automating_Submissions)'
195
  st.markdown(wiki_link, unsafe_allow_html=True)
196
 
197
+
198
  if __name__ == '__main__':
199
  main()