Joshua Lochner committed
Commit 8b71088 · 1 Parent(s): a9123fa

Abstract inference code
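
In short: the model-loading and video-selection options that src/predict.py and src/evaluate.py previously kept separately are pulled into a shared `InferenceArguments` dataclass in src/predict.py, and the YouTube channel-listing helper moves there too. A rough sketch of the hierarchy the diffs below introduce (field lists abbreviated and defaults simplified; not a verbatim excerpt):

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class InferenceArguments:
        # Shared by prediction and evaluation: model selection plus
        # video selection (--video_ids, --channel_id, --max_videos, ...).
        model_path: str = 'Xenova/sponsorblock-small'
        video_ids: list = field(default_factory=list)
        channel_id: Optional[str] = None

    @dataclass
    class PredictArguments(InferenceArguments):
        video_id: Optional[str] = None  # single-video convenience flag

    @dataclass
    class EvaluationArguments(InferenceArguments):
        output_file: Optional[str] = 'metrics.csv'  # defined in src/evaluate.py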

Files changed (2)
  1. src/evaluate.py +17 -92
  2. src/predict.py +146 -42
src/evaluate.py CHANGED
@@ -1,13 +1,10 @@
-import itertools
-import base64
-import re
-import requests
+
 from model import get_model_tokenizer
 from utils import jaccard
 from transformers import HfArgumentParser
 from preprocess import DatasetArguments, get_words
 from shared import GeneralArguments
-from predict import ClassifierArguments, predict, TrainingOutputArguments
+from predict import ClassifierArguments, predict, InferenceArguments
 from segment import extract_segment, word_start, word_end, SegmentationArguments, add_labels_to_words
 import pandas as pd
 from dataclasses import dataclass, field
@@ -21,18 +18,8 @@ from urllib.parse import quote
 
 
 @dataclass
-class EvaluationArguments(TrainingOutputArguments):
-    """
-    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
-    """
-    max_videos: Optional[int] = field(
-        default=None,
-        metadata={
-            'help': 'The number of videos to test on'
-        }
-    )
-    start_index: int = field(default=None, metadata={
-        'help': 'Video to start the evaluation at.'})
+class EvaluationArguments(InferenceArguments):
+    """Arguments pertaining to how evaluation will occur."""
     output_file: Optional[str] = field(
         default='metrics.csv',
         metadata={
@@ -40,13 +27,6 @@ class EvaluationArguments(TrainingOutputArguments):
         }
     )
 
-    channel_id: Optional[str] = field(
-        default=None,
-        metadata={
-            'help': 'Used to evaluate a channel'
-        }
-    )
-
 
 def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
     """Attach sponsor segments to closest prediction"""
@@ -144,56 +124,6 @@ def calculate_metrics(labelled_words, predictions):
     return metrics
 
 
-# Public innertube key (b64 encoded so that it is not incorrectly flagged)
-INNERTUBE_KEY = base64.b64decode(
-    b'QUl6YVN5QU9fRkoyU2xxVThRNFNURUhMR0NpbHdfWTlfMTFxY1c4').decode()
-
-YT_CONTEXT = {
-    'client': {
-        'userAgent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36,gzip(gfe)',
-        'clientName': 'WEB',
-        'clientVersion': '2.20211221.00.00',
-    }
-}
-_YT_INITIAL_DATA_RE = r'(?:window\s*\[\s*["\']ytInitialData["\']\s*\]|ytInitialData)\s*=\s*({.+?})\s*;\s*(?:var\s+meta|</script|\n)'
-
-
-def get_all_channel_vids(channel_id):
-    continuation = None
-    while True:
-        if continuation is None:
-            params = {'list': channel_id.replace('UC', 'UU', 1)}
-            response = requests.get(
-                'https://www.youtube.com/playlist', params=params)
-            items = json.loads(re.search(_YT_INITIAL_DATA_RE, response.text).group(1))['contents']['twoColumnBrowseResultsRenderer']['tabs'][0]['tabRenderer']['content'][
-                'sectionListRenderer']['contents'][0]['itemSectionRenderer']['contents'][0]['playlistVideoListRenderer']['contents']
-        else:
-            params = {'key': INNERTUBE_KEY}
-            data = {
-                'context': YT_CONTEXT,
-                'continuation': continuation
-            }
-            response = requests.post(
-                'https://www.youtube.com/youtubei/v1/browse', params=params, json=data)
-            items = response.json()[
-                'onResponseReceivedActions'][0]['appendContinuationItemsAction']['continuationItems']
-
-        new_token = None
-        for vid in items:
-            info = vid.get('playlistVideoRenderer')
-            if info:
-                yield info['videoId']
-                continue
-
-            info = vid.get('continuationItemRenderer')
-            if info:
-                new_token = info['continuationEndpoint']['continuationCommand']['token']
-
-        if new_token is None:
-            break
-        continuation = new_token
-
-
 def main():
     hf_parser = HfArgumentParser((
         EvaluationArguments,
@@ -205,30 +135,25 @@ def main():
 
     evaluation_args, dataset_args, segmentation_args, classifier_args, _ = hf_parser.parse_args_into_dataclasses()
 
-    model, tokenizer = get_model_tokenizer(evaluation_args.model_path, evaluation_args.cache_dir)
-
-    # # TODO find better way of evaluating videos not trained on
-    # dataset = load_dataset('json', data_files=os.path.join(
-    #     dataset_args.data_dir, dataset_args.validation_file))['train']
-    # video_ids = [row['video_id'] for row in dataset]
-
     # Load labelled data:
     final_path = os.path.join(
-        dataset_args.data_dir, dataset_args.processed_file)
+        dataset_args.data_dir, dataset_args.processed_database)
+
+    if not os.path.exists(final_path):
+        print('ERROR: Processed database not found.',
+              f'Run `python src/preprocess.py --update_database --do_process_database` to generate "{final_path}".')
+        return
+
+    model, tokenizer = get_model_tokenizer(
+        evaluation_args.model_path, evaluation_args.cache_dir)
 
     with open(final_path) as fp:
         final_data = json.load(fp)
 
-    if evaluation_args.channel_id is not None:
-        start = evaluation_args.start_index or 0
-        end = None if evaluation_args.max_videos is None else start + \
-            evaluation_args.max_videos
-
-        video_ids = list(itertools.islice(get_all_channel_vids(
-            evaluation_args.channel_id), start, end))
-        print('Found', len(video_ids), 'for channel', evaluation_args.channel_id)
+    if evaluation_args.video_ids:  # Use specified
+        video_ids = evaluation_args.video_ids
 
-    else:
+    else:  # Use items found in preprocessed database
         video_ids = list(final_data.keys())
        random.shuffle(video_ids)
 
@@ -255,7 +180,7 @@
 
        sponsor_segments = final_data.get(video_id)
        if not sponsor_segments:
-            # TODO remove - parse using whole database
+            print('No labels found for', video_id)
            continue
 
        words = get_words(video_id)
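
Note on the evaluate.py changes above: `--channel_id`, `--max_videos`, and `--start_index` did not disappear; they now live on `InferenceArguments` (see the predict.py diff below), so evaluation still accepts them, and a channel is expanded into concrete video IDs in `__post_init__` before `main()` runs. A minimal sketch of that flow, assuming `src/` is importable and network access is available (the channel ID below is a placeholder):

    from evaluate import EvaluationArguments

    # __post_init__ (inherited from InferenceArguments) slices the channel's
    # uploads with itertools.islice(get_all_channel_vids(...), start, end)
    # and appends them to video_ids before any evaluation happens.
    args = EvaluationArguments(channel_id='UCxxxxxxxxxxxxxxxxxxxxxx',
                               start_index=0, max_videos=5)
    print(args.video_ids)  # first five upload IDs of the channel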
src/predict.py CHANGED
@@ -1,3 +1,14 @@
+import itertools
+import base64
+import re
+import requests
+import json
+from transformers import HfArgumentParser
+from transformers.trainer_utils import get_last_checkpoint
+from dataclasses import dataclass, field
+import logging
+import os
+import itertools
 from utils import re_findall
 from shared import CustomTokens, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, OutputArguments, device, seconds_to_time
 from typing import Optional
@@ -11,17 +22,62 @@ from segment import (
     SegmentationArguments
 )
 import preprocess
-from errors import TranscriptError, ModelLoadError, ClassifierLoadError
+from errors import PredictionException, TranscriptError, ModelLoadError, ClassifierLoadError
 from model import ModelArguments, get_classifier_vectorizer, get_model_tokenizer
-from transformers import HfArgumentParser
-from transformers.trainer_utils import get_last_checkpoint
-from dataclasses import dataclass, field
-import logging
-import os
+
+
+# Public innertube key (b64 encoded so that it is not incorrectly flagged)
+INNERTUBE_KEY = base64.b64decode(
+    b'QUl6YVN5QU9fRkoyU2xxVThRNFNURUhMR0NpbHdfWTlfMTFxY1c4').decode()
+
+YT_CONTEXT = {
+    'client': {
+        'userAgent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36,gzip(gfe)',
+        'clientName': 'WEB',
+        'clientVersion': '2.20211221.00.00',
+    }
+}
+_YT_INITIAL_DATA_RE = r'(?:window\s*\[\s*["\']ytInitialData["\']\s*\]|ytInitialData)\s*=\s*({.+?})\s*;\s*(?:var\s+meta|</script|\n)'
+
+
+def get_all_channel_vids(channel_id):
+    continuation = None
+    while True:
+        if continuation is None:
+            params = {'list': channel_id.replace('UC', 'UU', 1)}
+            response = requests.get(
+                'https://www.youtube.com/playlist', params=params)
+            items = json.loads(re.search(_YT_INITIAL_DATA_RE, response.text).group(1))['contents']['twoColumnBrowseResultsRenderer']['tabs'][0]['tabRenderer']['content'][
+                'sectionListRenderer']['contents'][0]['itemSectionRenderer']['contents'][0]['playlistVideoListRenderer']['contents']
+        else:
+            params = {'key': INNERTUBE_KEY}
+            data = {
+                'context': YT_CONTEXT,
+                'continuation': continuation
+            }
+            response = requests.post(
+                'https://www.youtube.com/youtubei/v1/browse', params=params, json=data)
+            items = response.json()[
+                'onResponseReceivedActions'][0]['appendContinuationItemsAction']['continuationItems']
+
+        new_token = None
+        for vid in items:
+            info = vid.get('playlistVideoRenderer')
+            if info:
+                yield info['videoId']
+                continue
+
+            info = vid.get('continuationItemRenderer')
+            if info:
+                new_token = info['continuationEndpoint']['continuationCommand']['token']
+
+        if new_token is None:
+            break
+        continuation = new_token
 
 
 @dataclass
-class TrainingOutputArguments:
+class InferenceArguments:
 
     model_path: str = field(
         default='Xenova/sponsorblock-small',
@@ -34,28 +90,70 @@ class TrainingOutputArguments:
     output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
         'output_dir']
 
-    def __post_init__(self):
-        if self.model_path is not None:
-            return
-
-        if os.path.exists(self.output_dir):
-            last_checkpoint = get_last_checkpoint(self.output_dir)
-            if last_checkpoint is not None:
-                self.model_path = last_checkpoint
-                return
+    max_videos: Optional[int] = field(
+        default=None,
+        metadata={
+            'help': 'The number of videos to test on'
+        }
+    )
+    start_index: int = field(default=None, metadata={
+        'help': 'Video to start the evaluation at.'})
+    channel_id: Optional[str] = field(
+        default=None,
+        metadata={
+            'help': 'Used to evaluate a channel'
+        }
+    )
+    video_ids: str = field(
+        default_factory=lambda: [],
+        metadata={
+            'nargs': '+'
+        }
+    )
 
-        raise ModelLoadError(
-            'Unable to find model, explicitly set `--model_path`')
+    def __post_init__(self):
+        # Try to load model from latest checkpoint
+        if self.model_path is None:
+            if os.path.exists(self.output_dir):
+                last_checkpoint = get_last_checkpoint(self.output_dir)
+                if last_checkpoint is not None:
+                    self.model_path = last_checkpoint
+                else:
+                    raise ModelLoadError(
+                        'Unable to load model from checkpoint, explicitly set `--model_path`')
+            else:
+                raise ModelLoadError(
+                    f'Unable to find model in {self.output_dir}, explicitly set `--model_path`')
+
+        if any(len(video_id) != 11 for video_id in self.video_ids):
+            raise PredictionException('Invalid video IDs (length not 11)')
+
+        if self.channel_id is not None:
+            start = self.start_index or 0
+            end = None if self.max_videos is None else start + self.max_videos
+
+            channel_video_ids = list(itertools.islice(get_all_channel_vids(
+                self.channel_id), start, end))
+            print('Found', len(channel_video_ids),
+                  'for channel', self.channel_id)
+
+            self.video_ids += channel_video_ids
 
 
 @dataclass
-class PredictArguments(TrainingOutputArguments):
+class PredictArguments(InferenceArguments):
     video_id: str = field(
         default=None,
         metadata={
-            'help': 'Video to predict sponsorship segments for'}
+            'help': 'Video to predict segments for'}
     )
 
+    def __post_init__(self):
+        if self.video_id is not None:
+            self.video_ids.append(self.video_id)
+
+        super().__post_init__()
+
 
 _SEGMENT_START = START_SEGMENT_TEMPLATE.format(r'(?P<category>\w+)')
 _SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
@@ -297,31 +395,37 @@ def main():
     ))
     predict_args, segmentation_args, classifier_args = hf_parser.parse_args_into_dataclasses()
 
-    if predict_args.video_id is None:
-        print('No video ID supplied. Use `--video_id`.')
+    if not predict_args.video_ids:
+        print('No video IDs supplied. Use `--video_id`, `--video_ids`, or `--channel_id`.')
         return
 
-    model, tokenizer = get_model_tokenizer(predict_args.model_path, predict_args.cache_dir)
-
-    predict_args.video_id = predict_args.video_id.strip()
-    predictions = predict(predict_args.video_id, model, tokenizer,
-                          segmentation_args, classifier_args=classifier_args)
+    model, tokenizer = get_model_tokenizer(
+        predict_args.model_path, predict_args.cache_dir)
 
-    video_url = f'https://www.youtube.com/watch?v={predict_args.video_id}'
-    if not predictions:
-        print('No predictions found for', video_url)
-        return
-
-    print(len(predictions), 'predictions found for', video_url)
-    for index, prediction in enumerate(predictions, start=1):
-        print(f'Prediction #{index}:')
-        print('Text: "',
-              ' '.join([w['text'] for w in prediction['words']]), '"', sep='')
-        print('Time:', seconds_to_time(
-            prediction['start']), '\u2192', seconds_to_time(prediction['end']))
-        print('Category:', prediction.get('category'))
-        if 'probability' in prediction:
-            print('Probability:', prediction['probability'])
+    for video_id in predict_args.video_ids:
+        video_id = video_id.strip()
+        try:
+            predictions = predict(video_id, model, tokenizer,
+                                  segmentation_args, classifier_args=classifier_args)
+        except TranscriptError:
+            print('No transcript available for', video_id, end='\n\n')
+            continue
+        video_url = f'https://www.youtube.com/watch?v={video_id}'
+        if not predictions:
+            print('No predictions found for', video_url, end='\n\n')
+            continue
+
+        print(len(predictions), 'predictions found for', video_url)
+        for index, prediction in enumerate(predictions, start=1):
+            print(f'Prediction #{index}:')
+            print('Text: "',
+                  ' '.join([w['text'] for w in prediction['words']]), '"', sep='')
+            print('Time:', seconds_to_time(
+                prediction['start']), '\u2192', seconds_to_time(prediction['end']))
+            print('Category:', prediction.get('category'))
+            if 'probability' in prediction:
+                print('Probability:', prediction['probability'])
+            print()
         print()
 
 
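
Net effect on the prediction CLI shown above: `--video_id` remains as a backwards-compatible alias that is appended to `video_ids`, batches can now come from `--video_ids` or `--channel_id`, and a video without a transcript is skipped (its TranscriptError is caught) instead of aborting the whole run. A minimal sketch of the ID resolution, assuming `src/` is importable (the 11-character ID below is a placeholder):

    from predict import PredictArguments

    # PredictArguments.__post_init__ appends video_id to video_ids, then
    # InferenceArguments.__post_init__ validates ID lengths and expands
    # channel_id (if given) via get_all_channel_vids().
    args = PredictArguments(video_id='AAAAAAAAAAA')
    assert args.video_ids == ['AAAAAAAAAAA']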