Joshua Lochner committed
Commit • 813b772
Parent(s): 490a61c

Improve prediction pipeline

Files changed:
- src/predict.py +37 -24
- src/shared.py +8 -5
src/predict.py
CHANGED
@@ -135,43 +135,57 @@ def greedy_match(list, sublist):
     return best_i, best_j, best_k
 
 
-def predict_sponsor_text(text, model, tokenizer):
+def predict_sponsor_from_texts(texts, model, tokenizer):
+    clean_texts = list(map(preprocess.clean_text, texts))
+    return predict_sponsor_from_cleaned_texts(clean_texts, model, tokenizer)
+
+
+def predict_sponsor_from_cleaned_texts(cleaned_texts, model, tokenizer):
     """Given a body of text, predict the words which are part of the sponsor"""
     model_device = next(model.parameters()).device
-    input_ids = tokenizer(
-        f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}', return_tensors='pt', truncation=True).input_ids.to(model_device)
-
-    # Optimise output length so that we do not generate unnecessarily long texts
-    max_out_len = round(min(
-        max(
-            len(input_ids[0])/SAFETY_TOKENS_PERCENTAGE,
-            len(input_ids[0]) + MIN_SAFETY_TOKENS
-        ),
-        model.model_dim)
-    )
-    outputs = model.generate(input_ids, max_length=max_out_len)
-
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-
-def predict_sponsor_matches(text, model, tokenizer):
-    sponsorship_text = predict_sponsor_text(text, model, tokenizer)
-    return extract_sponsor_matches(sponsorship_text)
 
+    decoded_outputs = []
+    # Do individually, to avoid running out of memory for long videos
+    for cleaned_words in cleaned_texts:
+        text = CustomTokens.EXTRACT_SEGMENTS_PREFIX.value + \
+            ' '.join(cleaned_words)
+        input_ids = tokenizer(text, return_tensors='pt',
+                              truncation=True).input_ids.to(model_device)
+
+        # Optimise output length so that we do not generate unnecessarily long texts
+        max_out_len = round(min(
+            max(
+                len(input_ids[0])/SAFETY_TOKENS_PERCENTAGE,
+                len(input_ids[0]) + MIN_SAFETY_TOKENS
+            ),
+            model.model_dim)
+        )
+
+        outputs = model.generate(input_ids, max_length=max_out_len)
+        decoded_outputs.append(tokenizer.decode(
+            outputs[0], skip_special_tokens=True))
+
+    return decoded_outputs
 
 
 def segments_to_predictions(segments, model, tokenizer):
     predicted_time_ranges = []
 
+    cleaned_texts = [
+        [x['cleaned'] for x in cleaned_segment]
+        for cleaned_segment in segments
+    ]
+
+    sponsorship_texts = predict_sponsor_from_cleaned_texts(
+        cleaned_texts, model, tokenizer)
+
+    matches = extract_sponsor_matches(sponsorship_texts)
+
-    for match in matches:
+    for segment_matches, cleaned_batch, segment in zip(matches, cleaned_texts, segments):
+        for match in segment_matches:  # one segment might contain multiple sponsors/ir/selfpromos
             matched_text = match['text'].split()
-            # TODO skip if too short
 
             i1, j1, k1 = greedy_match(
                 cleaned_batch, matched_text[:MATCH_WINDOW])
@@ -179,7 +193,6 @@ def segments_to_predictions(segments, model, tokenizer):
                 cleaned_batch, matched_text[-MATCH_WINDOW:])
 
             extracted_words = segment[i1:i2+k2]
-
             if not extracted_words:
                 continue
 
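The sizing rule in predict_sponsor_from_cleaned_texts caps generation at the larger of a proportional budget and a fixed safety margin over the input length, clamped to the model maximum. Below is a minimal sketch of that rule in isolation; the constant values are assumptions for illustration (the real SAFETY_TOKENS_PERCENTAGE and MIN_SAFETY_TOKENS are defined elsewhere in the repo, and MODEL_MAX_LENGTH stands in for model.model_dim):

# Minimal sketch of the output-length rule, under assumed constants.
SAFETY_TOKENS_PERCENTAGE = 0.5  # assumed: allows outputs up to 2x the input length
MIN_SAFETY_TOKENS = 20          # assumed: fixed margin for very short inputs
MODEL_MAX_LENGTH = 512          # assumed: stand-in for model.model_dim


def max_output_length(num_input_tokens: int) -> int:
    # Larger of the proportional budget and the fixed margin,
    # clamped to the model maximum.
    return round(min(
        max(
            num_input_tokens / SAFETY_TOKENS_PERCENTAGE,
            num_input_tokens + MIN_SAFETY_TOKENS
        ),
        MODEL_MAX_LENGTH
    ))


for n in (10, 100, 400):
    print(n, '->', max_output_length(n))
# 10 -> 30   (fixed margin dominates for short inputs)
# 100 -> 200 (proportional budget dominates)
# 400 -> 512 (clamped to the model maximum)

Short transcripts get the fixed margin, mid-length ones the proportional budget, and long ones are clamped, which is what keeps generation cheap for long videos.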
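The other change in segments_to_predictions is structural: the whole batch goes through the model once, and the per-segment match lists are then re-aligned with the inputs by position. A shape-only sketch of that alignment, where the `matches` value is made up for illustration (in the pipeline it comes from extract_sponsor_matches applied to the decoded outputs):

# Shape-only sketch of the re-alignment step, with fabricated example data.
segments = [
    [{'text': 'Hey', 'cleaned': 'hey'}, {'text': 'guys!', 'cleaned': 'guys'}],
    [{'text': 'Thanks,', 'cleaned': 'thanks'}, {'text': 'sponsor.', 'cleaned': 'sponsor'}],
]

cleaned_texts = [
    [x['cleaned'] for x in cleaned_segment]
    for cleaned_segment in segments
]  # [['hey', 'guys'], ['thanks', 'sponsor']]

# One match list per segment, so the three sequences stay index-aligned
matches = [[], [{'text': 'thanks sponsor'}]]

for segment_matches, cleaned_batch, segment in zip(matches, cleaned_texts, segments):
    for match in segment_matches:
        matched_text = match['text'].split()
        print(matched_text, cleaned_batch)  # ['thanks', 'sponsor'] ['thanks', 'sponsor']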
src/shared.py
CHANGED
@@ -76,11 +76,14 @@ _SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
 SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'
 
 
-def extract_sponsor_matches(text):
-    if CustomTokens.NO_SEGMENT.value in text:
-        return []
-
-    return re_findall(SEGMENT_MATCH_RE, text)
+def extract_sponsor_matches(texts):
+    to_return = []
+    for text in texts:
+        if CustomTokens.NO_SEGMENT.value in text:
+            to_return.append([])
+        else:
+            to_return.append(re_findall(SEGMENT_MATCH_RE, text))
+    return to_return
 
 
 @dataclass
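extract_sponsor_matches now takes a list of decoded outputs and returns one match list per output, so its shape lines up with the batched prediction in src/predict.py. A standalone sketch of the same pattern using plain re, with assumed spellings for the segment markers and the no-segment token (the real patterns come from START_SEGMENT_TEMPLATE / END_SEGMENT_TEMPLATE and CustomTokens.NO_SEGMENT, and groupdict() stands in for the repo's re_findall helper):

import re

# Assumed marker spellings, for illustration only.
_SEGMENT_START = r'START_SEGMENT_(?P<category>\w+)'
_SEGMENT_END = r'END_SEGMENT_\w+'
NO_SEGMENT = 'NO_SEGMENT'

SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'


def extract_sponsor_matches(texts):
    # One (possibly empty) list of matches per decoded model output
    to_return = []
    for text in texts:
        if NO_SEGMENT in text:
            to_return.append([])
        else:
            to_return.append([m.groupdict()
                              for m in re.finditer(SEGMENT_MATCH_RE, text)])
    return to_return


outputs = [
    'NO_SEGMENT',
    'START_SEGMENT_sponsor this video is sponsored by example END_SEGMENT_sponsor',
]
print(extract_sponsor_matches(outputs))
# [[], [{'category': 'sponsor', 'text': 'this video is sponsored by example'}]]

The `(?:{_SEGMENT_END}|$)` alternation means a segment still matches when the model stops generating before emitting the closing marker, which the max_out_len cap above makes more likely.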