Joshua Lochner commited on
Commit
f11d2c2
·
1 Parent(s): 14ea568
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -73,10 +73,10 @@ CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
73
 
74
  @st.cache(allow_output_mutation=True)
75
  def load_predict(model_id):
76
- model = MODELS[model_id]
77
 
78
  # Use default segmentation and classification arguments
79
- evaluation_args = EvaluationArguments(model_path=model['repo_id'])
80
  segmentation_args = SegmentationArguments()
81
  classifier_args = ClassifierArguments()
82
 
@@ -98,13 +98,13 @@ def load_predict(model_id):
98
  )
99
 
100
  def predict_function(video_id):
101
- if video_id not in model['cache']:
102
- model['cache'][video_id] = pred(
103
  video_id, model, tokenizer,
104
  segmentation_args=segmentation_args,
105
  classifier_args=classifier_args
106
  )
107
- return model['cache'][video_id]
108
 
109
  return predict_function
110
 
 
73
 
74
  @st.cache(allow_output_mutation=True)
75
  def load_predict(model_id):
76
+ model_info = MODELS[model_id]
77
 
78
  # Use default segmentation and classification arguments
79
+ evaluation_args = EvaluationArguments(model_path=model_info['repo_id'])
80
  segmentation_args = SegmentationArguments()
81
  classifier_args = ClassifierArguments()
82
 
 
98
  )
99
 
100
  def predict_function(video_id):
101
+ if video_id not in model_info['cache']:
102
+ model_info['cache'][video_id] = pred(
103
  video_id, model, tokenizer,
104
  segmentation_args=segmentation_args,
105
  classifier_args=classifier_args
106
  )
107
+ return model_info['cache'][video_id]
108
 
109
  return predict_function
110