Joshua Lochner committed on
Commit 813b772
1 Parent(s): 490a61c

Improve prediction pipeline

Files changed (2):
  1. src/predict.py +37 -24
  2. src/shared.py +8 -5
src/predict.py CHANGED
@@ -135,43 +135,57 @@ def greedy_match(list, sublist):
     return best_i, best_j, best_k
 
 
-def predict_sponsor_text(text, model, tokenizer):
+def predict_sponsor_from_texts(texts, model, tokenizer):
+    clean_texts = list(map(preprocess.clean_text, texts))
+    return predict_sponsor_from_cleaned_texts(clean_texts, model, tokenizer)
+
+
+def predict_sponsor_from_cleaned_texts(cleaned_texts, model, tokenizer):
     """Given a body of text, predict the words which are part of the sponsor"""
     model_device = next(model.parameters()).device
-    input_ids = tokenizer(
-        f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}', return_tensors='pt', truncation=True).input_ids.to(model_device)
 
+    decoded_outputs = []
+    # Do individually, to avoid running out of memory for long videos
+    for cleaned_words in cleaned_texts:
+        text = CustomTokens.EXTRACT_SEGMENTS_PREFIX.value + \
+            ' '.join(cleaned_words)
+        input_ids = tokenizer(text, return_tensors='pt',
+                              truncation=True).input_ids.to(model_device)
-    max_out_len = round(min(
-        max(
-            len(input_ids[0])/SAFETY_TOKENS_PERCENTAGE,
-            len(input_ids[0]) + MIN_SAFETY_TOKENS
-        ),
-        model.model_dim))
-    outputs = model.generate(input_ids, max_length=max_out_len)
 
+        # Optimise output length so that we do not generate unnecessarily long texts
+        max_out_len = round(min(
+            max(
+                len(input_ids[0])/SAFETY_TOKENS_PERCENTAGE,
+                len(input_ids[0]) + MIN_SAFETY_TOKENS
+            ),
+            model.model_dim)
+        )
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
+        outputs = model.generate(input_ids, max_length=max_out_len)
+        decoded_outputs.append(tokenizer.decode(
+            outputs[0], skip_special_tokens=True))
 
-def predict_sponsor_matches(text, model, tokenizer):
-    sponsorship_text = predict_sponsor_text(text, model, tokenizer)
-    return extract_sponsor_matches(sponsorship_text)
+    return decoded_outputs
 
 
 def segments_to_predictions(segments, model, tokenizer):
     predicted_time_ranges = []
 
-    # TODO pass to model simultaneously, not in for loop
-    # use 2d array for input ids
-    for segment in segments:
-        cleaned_batch = [preprocess.clean_text(
-            word['text']) for word in segment]
-        batch_text = ' '.join(cleaned_batch)
+    cleaned_texts = [
+        [x['cleaned'] for x in cleaned_segment]
+        for cleaned_segment in segments
+    ]
+
+    sponsorship_texts = predict_sponsor_from_cleaned_texts(
+        cleaned_texts, model, tokenizer)
 
-        matches = predict_sponsor_matches(batch_text, model, tokenizer)
+    matches = extract_sponsor_matches(sponsorship_texts)
+
+    for segment_matches, cleaned_batch, segment in zip(matches, cleaned_texts, segments):
+
+        for match in segment_matches:  # one segment might contain multiple sponsors/ir/selfpromos
 
-        for match in matches:
             matched_text = match['text'].split()
-            # TODO skip if too short
 
             i1, j1, k1 = greedy_match(
                 cleaned_batch, matched_text[:MATCH_WINDOW])
@@ -179,7 +193,6 @@ def segments_to_predictions(segments, model, tokenizer):
             cleaned_batch, matched_text[-MATCH_WINDOW:])
 
         extracted_words = segment[i1:i2+k2]
-
         if not extracted_words:
             continue
 
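For reference, the change keeps the earlier output-length heuristic but now applies it per segment inside the loop: generation is capped at the model's limit while still being allowed to run somewhat past the input length. A minimal standalone sketch of that arithmetic, with illustrative constants (the real SAFETY_TOKENS_PERCENTAGE and MIN_SAFETY_TOKENS are defined elsewhere in src/predict.py, and model.model_dim is stubbed here by MODEL_MAX_LENGTH):

```python
# Sketch of the retained output-length cap; all constant values here are
# assumptions for illustration, not the repository's actual settings.
SAFETY_TOKENS_PERCENTAGE = 0.5  # assumed: permit output up to 2x the input
MIN_SAFETY_TOKENS = 20          # assumed: always leave some headroom
MODEL_MAX_LENGTH = 512          # assumed stand-in for model.model_dim


def max_output_length(num_input_tokens):
    # Take the larger of the two safety margins, then clamp to the model limit
    return round(min(
        max(
            num_input_tokens / SAFETY_TOKENS_PERCENTAGE,
            num_input_tokens + MIN_SAFETY_TOKENS
        ),
        MODEL_MAX_LENGTH
    ))


print(max_output_length(100))  # 200: the percentage margin dominates
print(max_output_length(500))  # 512: clamped to the assumed model limit
```

Running each segment through model.generate individually, as the in-code comment notes, trades batch throughput for a bounded memory footprint on long videos.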
src/shared.py CHANGED
@@ -76,11 +76,14 @@ _SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
 SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'
 
 
-def extract_sponsor_matches(text):
-    if CustomTokens.NO_SEGMENT.value in text:
-        return []
-
-    return re_findall(SEGMENT_MATCH_RE, text)
+def extract_sponsor_matches(texts):
+    to_return = []
+    for text in texts:
+        if CustomTokens.NO_SEGMENT.value in text:
+            to_return.append([])
+        else:
+            to_return.append(re_findall(SEGMENT_MATCH_RE, text))
+    return to_return
 
 
 @dataclass
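The src/shared.py side of the change makes extract_sponsor_matches accept a list of decoded outputs and return one match list per input, so empty results keep their slot instead of disappearing. A self-contained sketch of that contract, using simplified stand-ins for the segment templates and NO_SEGMENT token defined in src/shared.py, and approximating the re_findall helper with re.finditer plus groupdict:

```python
import re

# Assumed stand-ins for the templates/tokens defined in src/shared.py
_SEGMENT_START = r'START_SEGMENT_\w+'
_SEGMENT_END = r'END_SEGMENT_\w+'
NO_SEGMENT = 'NO_SEGMENT'
SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'


def extract_sponsor_matches(texts):
    # One result list per decoded output, preserving order so the caller
    # can zip() the matches back onto the corresponding segments
    to_return = []
    for text in texts:
        if NO_SEGMENT in text:
            to_return.append([])
        else:
            to_return.append([m.groupdict()
                              for m in re.finditer(SEGMENT_MATCH_RE, text)])
    return to_return


decoded = [
    'START_SEGMENT_sponsor check out our sponsor END_SEGMENT_sponsor',
    'NO_SEGMENT',
]
print(extract_sponsor_matches(decoded))
# [[{'text': 'check out our sponsor'}], []]
```

Keeping the empty-list placeholders is what lets segments_to_predictions in src/predict.py zip(matches, cleaned_texts, segments) without the three lists drifting out of alignment.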