Spaces:
Runtime error
Runtime error
Joshua Lochner
commited on
Commit
·
14ea568
1
Parent(s):
e68b946
Update prediction caching system to store predictions per model
Browse files
app.py
CHANGED
@@ -35,21 +35,31 @@ st.set_page_config(
|
|
35 |
|
36 |
# https://huggingface.co/docs/transformers/model_doc/t5
|
37 |
# https://huggingface.co/docs/transformers/model_doc/t5v1.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
MODELS = {
|
39 |
'Small (77M)': {
|
40 |
'pretrained': 'google/t5-v1_1-small',
|
41 |
'repo_id': 'Xenova/sponsorblock-small',
|
|
|
42 |
},
|
43 |
'Base v1 (220M)': {
|
44 |
'pretrained': 't5-base',
|
45 |
'repo_id': 'EColi/sponsorblock-base-v1',
|
|
|
46 |
},
|
47 |
|
48 |
'Base v1.1 (250M)': {
|
49 |
'pretrained': 'google/t5-v1_1-base',
|
50 |
'repo_id': 'Xenova/sponsorblock-base',
|
|
|
51 |
}
|
52 |
-
|
53 |
}
|
54 |
|
55 |
CATGEGORY_OPTIONS = {
|
@@ -62,18 +72,11 @@ CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
|
|
62 |
|
63 |
|
64 |
@st.cache(allow_output_mutation=True)
|
65 |
-
def
|
66 |
-
|
67 |
-
|
68 |
|
69 |
-
# Faster caching system for predictions (No need to hash)
|
70 |
-
predictions_cache = persistdata()
|
71 |
-
|
72 |
-
|
73 |
-
@st.cache(allow_output_mutation=True)
|
74 |
-
def load_predict(model_path):
|
75 |
# Use default segmentation and classification arguments
|
76 |
-
evaluation_args = EvaluationArguments(model_path=
|
77 |
segmentation_args = SegmentationArguments()
|
78 |
classifier_args = ClassifierArguments()
|
79 |
|
@@ -95,13 +98,13 @@ def load_predict(model_path):
|
|
95 |
)
|
96 |
|
97 |
def predict_function(video_id):
|
98 |
-
if video_id not in
|
99 |
-
|
100 |
video_id, model, tokenizer,
|
101 |
segmentation_args=segmentation_args,
|
102 |
classifier_args=classifier_args
|
103 |
)
|
104 |
-
return
|
105 |
|
106 |
return predict_function
|
107 |
|
@@ -115,7 +118,7 @@ def main():
|
|
115 |
model_id = st.selectbox('Select model', MODELS.keys(), index=0)
|
116 |
|
117 |
# Load prediction function
|
118 |
-
predict = load_predict(
|
119 |
|
120 |
video_id = st.text_input('Video ID:') # , placeholder='e.g., axtQvkSpoto'
|
121 |
|
|
|
35 |
|
36 |
# https://huggingface.co/docs/transformers/model_doc/t5
|
37 |
# https://huggingface.co/docs/transformers/model_doc/t5v1.1
|
38 |
+
|
39 |
+
|
40 |
+
# Faster caching system for predictions (No need to hash)
|
41 |
+
@st.cache(allow_output_mutation=True)
|
42 |
+
def persistdata():
|
43 |
+
return {}
|
44 |
+
|
45 |
+
|
46 |
MODELS = {
|
47 |
'Small (77M)': {
|
48 |
'pretrained': 'google/t5-v1_1-small',
|
49 |
'repo_id': 'Xenova/sponsorblock-small',
|
50 |
+
'cache': persistdata()
|
51 |
},
|
52 |
'Base v1 (220M)': {
|
53 |
'pretrained': 't5-base',
|
54 |
'repo_id': 'EColi/sponsorblock-base-v1',
|
55 |
+
'cache': persistdata()
|
56 |
},
|
57 |
|
58 |
'Base v1.1 (250M)': {
|
59 |
'pretrained': 'google/t5-v1_1-base',
|
60 |
'repo_id': 'Xenova/sponsorblock-base',
|
61 |
+
'cache': persistdata()
|
62 |
}
|
|
|
63 |
}
|
64 |
|
65 |
CATGEGORY_OPTIONS = {
|
|
|
72 |
|
73 |
|
74 |
@st.cache(allow_output_mutation=True)
|
75 |
+
def load_predict(model_id):
|
76 |
+
model = MODELS[model_id]
|
|
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
# Use default segmentation and classification arguments
|
79 |
+
evaluation_args = EvaluationArguments(model_path=model['repo_id'])
|
80 |
segmentation_args = SegmentationArguments()
|
81 |
classifier_args = ClassifierArguments()
|
82 |
|
|
|
98 |
)
|
99 |
|
100 |
def predict_function(video_id):
|
101 |
+
if video_id not in model['cache']:
|
102 |
+
model['cache'][video_id] = pred(
|
103 |
video_id, model, tokenizer,
|
104 |
segmentation_args=segmentation_args,
|
105 |
classifier_args=classifier_args
|
106 |
)
|
107 |
+
return model['cache'][video_id]
|
108 |
|
109 |
return predict_function
|
110 |
|
|
|
118 |
model_id = st.selectbox('Select model', MODELS.keys(), index=0)
|
119 |
|
120 |
# Load prediction function
|
121 |
+
predict = load_predict(model_id)
|
122 |
|
123 |
video_id = st.text_input('Video ID:') # , placeholder='e.g., axtQvkSpoto'
|
124 |
|