Joshua Lochner committed
Commit bb58e90 • 1 Parent(s): f9281a4

Move `seconds_to_time` to shared

Browse files
- src/predict.py +18 -24
- src/shared.py +14 -1
src/predict.py CHANGED

@@ -5,7 +5,8 @@ from typing import Optional
 from segment import (
     generate_segments,
     extract_segment,
-
+    MIN_SAFETY_TOKENS,
+    SAFETY_TOKENS_PERCENTAGE,
     CustomTokens,
     word_start,
     word_end,
@@ -13,7 +14,7 @@ from segment import (
 )
 import preprocess
 from errors import TranscriptError
-from model import get_classifier_vectorizer
+from model import get_classifier_vectorizer, get_model_tokenizer
 from transformers import (
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
@@ -26,25 +27,15 @@ import logging
 
 import re
 
-
-def seconds_to_time(seconds, remove_leading_zeroes=False):
-    fractional = round(seconds % 1, 3)
-    fractional = '' if fractional == 0 else str(fractional)[1:]
-    h, remainder = divmod(abs(int(seconds)), 3600)
-    m, s = divmod(remainder, 60)
-    hms = f'{h:02}:{m:02}:{s:02}'
-    if remove_leading_zeroes:
-        hms = re.sub(r'^0(?:0:0?)?', '', hms)
-    return f"{'-' if seconds < 0 else ''}{hms}{fractional}"
-
-
+from shared import seconds_to_time
 @dataclass
 class TrainingOutputArguments:
 
     model_path: str = field(
         default=None,
         metadata={
-            'help': 'Path to pretrained model used for prediction'
+            'help': 'Path to pretrained model used for prediction'
+        }
     )
 
     output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
@@ -106,7 +97,8 @@ class ClassifierArguments:
         default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
 
 
-def filter_and_add_probabilities(predictions, classifier_args): # classifier, vectorizer,
+# classifier, vectorizer,
+def filter_and_add_probabilities(predictions, classifier_args):
     """Use classifier to filter predictions"""
     if not predictions:
         return predictions
@@ -135,7 +127,7 @@ def filter_and_add_probabilities(predictions, classifier_args): # classifier, v
             continue  # Ignore
 
         if (prediction['category'] not in predicted_probabilities) \
-
+                or (classifier_category is not None and classifier_probability > 0.5):  # TODO make param
            # Unknown category or we are confident enough to overrule,
            # so change category to what was predicted by classifier
            prediction['category'] = classifier_category
@@ -175,7 +167,8 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
 
     # TODO add back
     if classifier_args is not None:
-        predictions = filter_and_add_probabilities(
+        predictions = filter_and_add_probabilities(
+            predictions, classifier_args)
 
     return predictions
 
@@ -205,8 +198,12 @@ def predict_sponsor_text(text, model, tokenizer):
     input_ids = tokenizer(
         f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}', return_tensors='pt', truncation=True).input_ids.to(device())
 
-
-
+    max_out_len = round(min(
+        max(
+            len(input_ids[0])/SAFETY_TOKENS_PERCENTAGE,
+            len(input_ids[0]) + MIN_SAFETY_TOKENS
+        ),
+        model.model_dim))
     outputs = model.generate(input_ids, max_length=max_out_len)
 
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -300,10 +297,7 @@ def main():
        print('No video ID supplied. Use `--video_id`.')
        return
 
-    model =
-    model.to(device())
-
-    tokenizer = AutoTokenizer.from_pretrained(predict_args.model_path)
+    model, tokenizer = get_model_tokenizer(predict_args.model_path)
 
    predict_args.video_id = predict_args.video_id.strip()
    predictions = predict(predict_args.video_id, model, tokenizer,
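The added `max_out_len` bound in `predict_sponsor_text` lets the generated segment text run somewhat longer than the input (the input length scaled by `SAFETY_TOKENS_PERCENTAGE`, or padded by `MIN_SAFETY_TOKENS`, whichever allows more) while never exceeding the model width. A minimal standalone sketch of that bound; the constant values below are illustrative assumptions, since the real ones come from `src/segment.py` and `model.model_dim`:

# Sketch of the output-length bound added in predict_sponsor_text().
# NOTE: the constants below are assumed placeholder values, not the
# values defined in src/segment.py or the model's actual width.
MIN_SAFETY_TOKENS = 10          # assumed
SAFETY_TOKENS_PERCENTAGE = 0.8  # assumed
MODEL_DIM = 512                 # assumed stand-in for model.model_dim

def max_output_length(num_input_tokens):
    # Allow some headroom over the input length, but never exceed the model width.
    return round(min(
        max(
            num_input_tokens / SAFETY_TOKENS_PERCENTAGE,
            num_input_tokens + MIN_SAFETY_TOKENS
        ),
        MODEL_DIM))

print(max_output_length(32))   # 42  -> the "+ MIN_SAFETY_TOKENS" floor wins
print(max_output_length(100))  # 125 -> the percentage term wins
print(max_output_length(600))  # 512 -> capped at the model width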
src/shared.py CHANGED

@@ -1,3 +1,4 @@
+import re
 import gc
 from time import time_ns
 import random
@@ -11,6 +12,7 @@ from enum import Enum
 START_SEGMENT_TEMPLATE = 'START_{}_TOKEN'
 END_SEGMENT_TEMPLATE = 'END_{}_TOKEN'
 
+
 class CustomTokens(Enum):
     EXTRACT_SEGMENTS_PREFIX = 'EXTRACT_SEGMENTS: '
 
@@ -29,7 +31,7 @@ class CustomTokens(Enum):
     LAUGHTER = '[Laughter]'
 
     PROFANITY = 'PROFANITY_TOKEN'
-
+
     # Segment tokens
     NO_SEGMENT = 'NO_SEGMENT_TOKEN'
 
@@ -103,6 +105,17 @@ def device():
     return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
+def seconds_to_time(seconds, remove_leading_zeroes=False):
+    fractional = round(seconds % 1, 3)
+    fractional = '' if fractional == 0 else str(fractional)[1:]
+    h, remainder = divmod(abs(int(seconds)), 3600)
+    m, s = divmod(remainder, 60)
+    hms = f'{h:02}:{m:02}:{s:02}'
+    if remove_leading_zeroes:
+        hms = re.sub(r'^0(?:0:0?)?', '', hms)
+    return f"{'-' if seconds < 0 else ''}{hms}{fractional}"
+
+
 def reset():
     torch.clear_autocast_cache()
     torch.cuda.empty_cache()
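For reference, the relocated `seconds_to_time` helper formats a (possibly fractional or negative) number of seconds as HH:MM:SS with an optional fractional part; `remove_leading_zeroes=True` strips empty leading hour/minute digits. A quick check of the behaviour implied by the implementation above:

from shared import seconds_to_time  # after this commit the helper lives in src/shared.py

print(seconds_to_time(3661.5))                              # 01:01:01.5
print(seconds_to_time(75.25, remove_leading_zeroes=True))   # 1:15.25
print(seconds_to_time(-5))                                  # -00:00:05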