Joshua Lochner commited on
Commit
5fbdd3c
·
1 Parent(s): 5f40236

Add source code

Browse files
Files changed (9) hide show
  1. src/errors.py +13 -0
  2. src/evaluate.py +244 -0
  3. src/model.py +111 -0
  4. src/predict.py +278 -0
  5. src/preprocess.py +786 -0
  6. src/segment.py +142 -0
  7. src/shared.py +96 -0
  8. src/train.py +508 -0
  9. src/utils.py +86 -0
src/errors.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class SponsorBlockException(Exception):
2
+ """Base class for all sponsor block exceptions"""
3
+ pass
4
+
5
+
6
+ class PredictionException(SponsorBlockException):
7
+ """An exception was occurred while predicting sponsor segments"""
8
+ pass
9
+
10
+
11
+ class TranscriptError(SponsorBlockException):
12
+ """An exception was occurred while retrieving the video transcript"""
13
+ pass
src/evaluate.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import (
3
+ AutoModelForSeq2SeqLM,
4
+ AutoTokenizer,
5
+ HfArgumentParser
6
+ )
7
+ from preprocess import DatasetArguments, ProcessedArguments, get_words
8
+ from model import get_classifier_vectorizer
9
+ from shared import device
10
+ from predict import ClassifierArguments, PredictArguments, predict, filter_predictions
11
+ from segment import word_start, word_end, SegmentationArguments, add_labels_to_words
12
+ import pandas as pd
13
+ from dataclasses import dataclass, field
14
+ from typing import Optional
15
+ from tqdm import tqdm
16
+ import json
17
+ import os
18
+ import random
19
+
20
+
21
+ @dataclass
22
+ class EvaluationArguments:
23
+ """
24
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
25
+ """
26
+ max_videos: Optional[int] = field(
27
+ default=None,
28
+ metadata={
29
+ 'help': 'The number of videos to test on'
30
+ }
31
+ )
32
+ model_path: Optional[str] = PredictArguments.__dataclass_fields__[
33
+ 'model_path']
34
+ data_dir: Optional[str] = DatasetArguments.__dataclass_fields__['data_dir']
35
+ dataset: Optional[str] = DatasetArguments.__dataclass_fields__[
36
+ 'validation_file']
37
+
38
+ output_file: Optional[str] = field(
39
+ default='metrics.csv',
40
+ metadata={
41
+ 'help': 'Save metrics to output file'
42
+ }
43
+ )
44
+
45
+
46
+ def jaccard(x1, x2, y1, y2):
47
+ # Calculate jaccard index
48
+ intersection = max(0, min(x2, y2)-max(x1, y1))
49
+ filled_union = max(x2, y2) - min(x1, y1)
50
+ return intersection/filled_union
51
+
52
+
53
+ def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
54
+ """Attach sponsor segments to closest prediction"""
55
+ for prediction in predictions:
56
+ prediction['best_overlap'] = 0
57
+ prediction['best_sponsorship'] = None
58
+
59
+ # Assign predictions to actual (labelled) sponsored segments
60
+ for sponsor_segment in sponsor_segments:
61
+ sponsor_segment['best_overlap'] = 0
62
+ sponsor_segment['best_prediction'] = None
63
+
64
+ for prediction in predictions:
65
+
66
+ j = jaccard(prediction['start'], prediction['end'],
67
+ sponsor_segment['start'], sponsor_segment['end'])
68
+ if sponsor_segment['best_overlap'] < j:
69
+ sponsor_segment['best_overlap'] = j
70
+ sponsor_segment['best_prediction'] = prediction
71
+
72
+ if prediction['best_overlap'] < j:
73
+ prediction['best_overlap'] = j
74
+ prediction['best_sponsorship'] = sponsor_segment
75
+
76
+ return sponsor_segments
77
+
78
+
79
+ def calculate_metrics(labelled_words, predictions):
80
+
81
+ metrics = {
82
+ 'true_positive': 0, # Is sponsor, predicted sponsor
83
+ # Is sponsor, predicted not sponsor (i.e., missed it - bad)
84
+ 'false_negative': 0,
85
+ # Is not sponsor, predicted sponsor (classified incorectly, not that bad since we do manual checking afterwards)
86
+ 'false_positive': 0,
87
+ 'true_negative': 0, # Is not sponsor, predicted not sponsor
88
+ }
89
+
90
+ metrics['video_duration'] = word_end(
91
+ labelled_words[-1])-word_start(labelled_words[0])
92
+
93
+ for index, word in enumerate(labelled_words):
94
+ if index >= len(labelled_words) - 1:
95
+ continue
96
+
97
+ # TODO make sure words with manual transcripts
98
+ duration = labelled_words[index+1]['start'] - word['start']
99
+
100
+ predicted_sponsor = False
101
+ for p in predictions:
102
+ # Is in some prediction
103
+ if p['start'] <= word['start'] <= p['end']:
104
+ predicted_sponsor = True
105
+ break
106
+
107
+ if predicted_sponsor:
108
+ # total_positive_time += duration
109
+ if word['sponsor']: # Is actual sponsor
110
+ metrics['true_positive'] += duration
111
+ else:
112
+ metrics['false_positive'] += duration
113
+ else:
114
+ # total_negative_time += duration
115
+ if word['sponsor']: # Is actual sponsor
116
+ metrics['false_negative'] += duration
117
+ else:
118
+ metrics['true_negative'] += duration
119
+
120
+ # NOTE In cases where we encounter division by 0, we say that the value is 1
121
+ # https://stats.stackexchange.com/a/1775
122
+ # (Precision) TP+FP=0: means that all instances were predicted as negative
123
+ # (Recall) TP+FN=0: means that there were no positive cases in the input data
124
+
125
+ # The fraction of predictions our model got right
126
+ # Can simplify, but use full formula
127
+ z = metrics['true_positive'] + metrics['true_negative'] + \
128
+ metrics['false_positive'] + metrics['false_negative']
129
+ metrics['accuracy'] = (
130
+ (metrics['true_positive'] + metrics['true_negative']) / z) if z > 0 else 1
131
+
132
+ # What proportion of positive identifications was actually correct?
133
+ z = metrics['true_positive'] + metrics['false_positive']
134
+ metrics['precision'] = (metrics['true_positive'] / z) if z > 0 else 1
135
+
136
+ # What proportion of actual positives was identified correctly?
137
+ z = metrics['true_positive'] + metrics['false_negative']
138
+ metrics['recall'] = (metrics['true_positive'] / z) if z > 0 else 1
139
+
140
+ # https://deepai.org/machine-learning-glossary-and-terms/f-score
141
+
142
+ s = metrics['precision'] + metrics['recall']
143
+ metrics['f-score'] = (2 * (metrics['precision'] *
144
+ metrics['recall']) / s) if s > 0 else 0
145
+
146
+ return metrics
147
+
148
+
149
+ def main():
150
+ hf_parser = HfArgumentParser((
151
+ EvaluationArguments,
152
+ ProcessedArguments,
153
+ SegmentationArguments,
154
+ ClassifierArguments
155
+ ))
156
+
157
+ evaluation_args, processed_args, segmentation_args, classifier_args = hf_parser.parse_args_into_dataclasses()
158
+
159
+ model = AutoModelForSeq2SeqLM.from_pretrained(evaluation_args.model_path)
160
+ model.to(device())
161
+
162
+ tokenizer = AutoTokenizer.from_pretrained(evaluation_args.model_path)
163
+
164
+ dataset = load_dataset('json', data_files=os.path.join(
165
+ evaluation_args.data_dir, evaluation_args.dataset))['train']
166
+
167
+ video_ids = [row['video_id'] for row in dataset]
168
+ random.shuffle(video_ids) # TODO Make param
169
+
170
+ if evaluation_args.max_videos is not None:
171
+ video_ids = video_ids[:evaluation_args.max_videos]
172
+
173
+ # Load labelled data:
174
+ final_path = os.path.join(
175
+ processed_args.processed_dir, processed_args.processed_file)
176
+
177
+ with open(final_path) as fp:
178
+ final_data = json.load(fp)
179
+
180
+ classifier, vectorizer = get_classifier_vectorizer(classifier_args)
181
+
182
+ total_accuracy = 0
183
+ total_precision = 0
184
+ total_recall = 0
185
+ total_fscore = 0
186
+
187
+ count = 0
188
+ out_metrics = []
189
+
190
+ try:
191
+ with tqdm(video_ids) as progress:
192
+ for video_id in progress:
193
+ progress.set_description(f'Processing {video_id}')
194
+ sponsor_segments = final_data.get(video_id, [])
195
+
196
+ words = get_words(video_id)
197
+ if not words:
198
+ continue
199
+
200
+ count += 1
201
+
202
+ # Make predictions
203
+ predictions = predict(video_id, model, tokenizer,
204
+ segmentation_args, words)
205
+
206
+ # Filter predictions
207
+ predictions = filter_predictions(
208
+ predictions, classifier, vectorizer, classifier_args)
209
+
210
+ labelled_words = add_labels_to_words(words, sponsor_segments)
211
+ met = calculate_metrics(labelled_words, predictions)
212
+ met['video_id'] = video_id
213
+
214
+ out_metrics.append(met)
215
+
216
+ total_accuracy += met['accuracy']
217
+ total_precision += met['precision']
218
+ total_recall += met['recall']
219
+ total_fscore += met['f-score']
220
+
221
+ progress.set_postfix({
222
+ 'accuracy': total_accuracy/count,
223
+ 'precision': total_precision/count,
224
+ 'recall': total_recall/count,
225
+ 'f-score': total_fscore/count
226
+ })
227
+
228
+ labelled_predicted_segments = attach_predictions_to_sponsor_segments(
229
+ predictions, sponsor_segments)
230
+ for seg in labelled_predicted_segments:
231
+ if seg['best_prediction'] is None:
232
+ print('\nNo match found for', seg)
233
+
234
+ except KeyboardInterrupt:
235
+ pass
236
+
237
+ df = pd.DataFrame(out_metrics)
238
+
239
+ df.to_csv(evaluation_args.output_file)
240
+ print(df.mean())
241
+
242
+
243
+ if __name__ == '__main__':
244
+ main()
src/model.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import os
3
+ from shared import CustomTokens
4
+ from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
5
+ from dataclasses import dataclass, field
6
+ from typing import Optional
7
+
8
+
9
+ @dataclass
10
+ class ModelArguments:
11
+ """
12
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
13
+ """
14
+
15
+ model_name_or_path: str = field(
16
+ default='google/t5-v1_1-small', # t5-small
17
+ metadata={
18
+ 'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
19
+ )
20
+ # config_name: Optional[str] = field( # TODO remove?
21
+ # default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'}
22
+ # )
23
+ tokenizer_name: Optional[str] = field(
24
+ default=None, metadata={'help': 'Pretrained tokenizer name or path if not the same as model_name'}
25
+ )
26
+ cache_dir: Optional[str] = field(
27
+ default=None,
28
+ metadata={
29
+ 'help': 'Where to store the pretrained models downloaded from huggingface.co'},
30
+ )
31
+ use_fast_tokenizer: bool = field( # TODO remove?
32
+ default=True,
33
+ metadata={
34
+ 'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'},
35
+ )
36
+ model_revision: str = field( # TODO remove?
37
+ default='main',
38
+ metadata={
39
+ 'help': 'The specific model version to use (can be a branch name, tag name or commit id).'},
40
+ )
41
+ use_auth_token: bool = field(
42
+ default=False,
43
+ metadata={
44
+ 'help': 'Will use the token generated when running `transformers-cli login` (necessary to use this script '
45
+ 'with private models).'
46
+ },
47
+ )
48
+ resize_position_embeddings: Optional[bool] = field(
49
+ default=None,
50
+ metadata={
51
+ 'help': "Whether to automatically resize the position embeddings if `max_source_length` exceeds the model's position embeddings."
52
+ },
53
+ )
54
+
55
+
56
+ def get_model(model_args, use_cache=True):
57
+ name = model_args.model_name_or_path
58
+ cached_path = f'models/{name}'
59
+
60
+ # Model created after tokenizer:
61
+ if use_cache and os.path.exists(os.path.join(cached_path, 'pytorch_model.bin')):
62
+ name = cached_path
63
+
64
+ config = AutoConfig.from_pretrained(
65
+ name,
66
+ cache_dir=model_args.cache_dir,
67
+ revision=model_args.model_revision,
68
+ use_auth_token=True if model_args.use_auth_token else None,
69
+ )
70
+
71
+ model = AutoModelForSeq2SeqLM.from_pretrained(
72
+ name,
73
+ from_tf='.ckpt' in name,
74
+ config=config,
75
+ cache_dir=model_args.cache_dir,
76
+ revision=model_args.model_revision,
77
+ use_auth_token=True if model_args.use_auth_token else None,
78
+ )
79
+
80
+ return model
81
+
82
+
83
+ def get_tokenizer(model_args, use_cache=True):
84
+ name = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
85
+
86
+ cached_path = f'models/{name}'
87
+
88
+ if use_cache and os.path.exists(os.path.join(cached_path, 'tokenizer.json')):
89
+ name = cached_path
90
+
91
+ tokenizer = AutoTokenizer.from_pretrained(
92
+ name,
93
+ cache_dir=model_args.cache_dir,
94
+ use_fast=model_args.use_fast_tokenizer,
95
+ revision=model_args.model_revision,
96
+ use_auth_token=True if model_args.use_auth_token else None,
97
+ )
98
+
99
+ CustomTokens.add_custom_tokens(tokenizer)
100
+
101
+ return tokenizer
102
+
103
+
104
+ def get_classifier_vectorizer(classifier_args):
105
+ with open(os.path.join(classifier_args.classifier_dir, classifier_args.classifier_file), 'rb') as fp:
106
+ classifier = pickle.load(fp)
107
+
108
+ with open(os.path.join(classifier_args.classifier_dir, classifier_args.vectorizer_file), 'rb') as fp:
109
+ vectorizer = pickle.load(fp)
110
+
111
+ return classifier, vectorizer
src/predict.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from shared import OutputArguments
2
+ from typing import Optional
3
+ from segment import (
4
+ generate_segments,
5
+ extract_segment,
6
+ SAFETY_TOKENS,
7
+ CustomTokens,
8
+ word_start,
9
+ word_end,
10
+ SegmentationArguments
11
+ )
12
+ import preprocess
13
+ import re
14
+ from errors import TranscriptError
15
+ from model import get_classifier_vectorizer
16
+ from transformers import (
17
+ AutoModelForSeq2SeqLM,
18
+ AutoTokenizer
19
+ )
20
+ from dataclasses import dataclass, field
21
+ from transformers import HfArgumentParser
22
+ from shared import device
23
+ import logging
24
+ from transformers.trainer_utils import get_last_checkpoint
25
+
26
+
27
+ def seconds_to_time(seconds):
28
+ h, remainder = divmod(abs(int(seconds)), 3600)
29
+ m, s = divmod(remainder, 60)
30
+ return f"{'-' if seconds < 0 else ''}{h:02}:{m:02}:{s:02}"
31
+
32
+
33
+ @dataclass
34
+ class PredictArguments:
35
+
36
+ video_id: str = field(
37
+ metadata={
38
+ 'help': 'Video to predict sponsorship segments for'}
39
+ )
40
+
41
+ model_path: str = field(
42
+ default=None,
43
+ metadata={
44
+ 'help': 'Path to pretrained model used for prediction'}
45
+ )
46
+
47
+ output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
48
+ 'output_dir']
49
+
50
+ def __post_init__(self):
51
+ if self.model_path is not None:
52
+ return
53
+
54
+ last_checkpoint = get_last_checkpoint(self.output_dir)
55
+ if last_checkpoint is not None:
56
+ self.model_path = last_checkpoint
57
+ else:
58
+ raise Exception(
59
+ 'Unable to find model, explicitly set `--model_path`')
60
+
61
+
62
+ SPONSOR_MATCH_RE = fr'(?<={CustomTokens.START_SPONSOR.value})\s*(.*?)\s*(?={CustomTokens.END_SPONSOR.value}|$)'
63
+
64
+ MATCH_WINDOW = 25 # Increase for accuracy, but takes longer: O(n^3)
65
+ MERGE_TIME_WITHIN = 8 # Merge predictions if they are within x seconds
66
+
67
+
68
+ @dataclass
69
+ class ClassifierArguments:
70
+ classifier_dir: Optional[str] = field(
71
+ default='classifiers',
72
+ metadata={
73
+ 'help': 'The directory that contains the classifier and vectorizer.'
74
+ }
75
+ )
76
+
77
+ classifier_file: Optional[str] = field(
78
+ default='classifier.pickle',
79
+ metadata={
80
+ 'help': 'The name of the classifier'
81
+ }
82
+ )
83
+
84
+ vectorizer_file: Optional[str] = field(
85
+ default='vectorizer.pickle',
86
+ metadata={
87
+ 'help': 'The name of the vectorizer'
88
+ }
89
+ )
90
+
91
+ min_probability: float = field(
92
+ default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
93
+
94
+
95
+ def filter_predictions(predictions, classifier, vectorizer, classifier_args):
96
+ """Use classifier to filter predictions"""
97
+ if not predictions:
98
+ return predictions
99
+
100
+ transformed_segments = vectorizer.transform([
101
+ preprocess.clean_text(' '.join([x['text'] for x in pred['words']]))
102
+ for pred in predictions
103
+ ])
104
+ probabilities = classifier.predict_proba(transformed_segments)
105
+
106
+ filtered_predictions = []
107
+ for prediction, probability in zip(predictions, probabilities):
108
+ prediction['probability'] = probability[1]
109
+
110
+ if prediction['probability'] >= classifier_args.min_probability:
111
+ filtered_predictions.append(prediction)
112
+ # else:
113
+ # print('removing segment', prediction)
114
+
115
+ return filtered_predictions
116
+
117
+
118
+ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifier_args=None):
119
+ # Allow words to be passed in so that we don't have to get the words if we already have them
120
+ if words is None:
121
+ words = preprocess.get_words(video_id)
122
+ if not words:
123
+ raise TranscriptError('Unable to retrieve transcript')
124
+
125
+ segments = generate_segments(
126
+ words,
127
+ tokenizer,
128
+ segmentation_args
129
+ )
130
+
131
+ predictions = segments_to_prediction_times(segments, model, tokenizer)
132
+
133
+ # Add words back to time_ranges
134
+ for prediction in predictions:
135
+ # Stores words in the range
136
+ prediction['words'] = extract_segment(
137
+ words, prediction['start'], prediction['end'])
138
+
139
+ if classifier_args is not None:
140
+ classifier, vectorizer = get_classifier_vectorizer(classifier_args)
141
+ predictions = filter_predictions(
142
+ predictions, classifier, vectorizer, classifier_args)
143
+
144
+ return predictions
145
+
146
+
147
+ def greedy_match(list, sublist):
148
+ # Return index and length of longest matching sublist
149
+
150
+ best_i = -1
151
+ best_j = -1
152
+ best_k = 0
153
+
154
+ for i in range(len(list)): # Start position in main list
155
+ for j in range(len(sublist)): # Start position in sublist
156
+ for k in range(len(sublist)-j, 0, -1): # Width of sublist window
157
+ if k > best_k and list[i:i+k] == sublist[j:j+k]:
158
+ best_i, best_j, best_k = i, j, k
159
+ break # Since window size decreases
160
+
161
+ return best_i, best_j, best_k
162
+
163
+
164
+ DEFAULT_TOKEN_PREFIX = 'summarize: '
165
+
166
+
167
+ def predict_sponsor_text(text, model, tokenizer):
168
+ """Given a body of text, predict the words which are part of the sponsor"""
169
+ input_ids = tokenizer(
170
+ f'{DEFAULT_TOKEN_PREFIX}{text}', return_tensors='pt', truncation=True).input_ids
171
+
172
+ # Can't be longer than input length + SAFETY_TOKENS or model input dim
173
+ max_out_len = min(len(input_ids[0]) + SAFETY_TOKENS, model.model_dim)
174
+ outputs = model.generate(input_ids, max_length=max_out_len)
175
+
176
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
177
+
178
+
179
+ def predict_sponsor_matches(text, model, tokenizer):
180
+ sponsorship_text = predict_sponsor_text(text, model, tokenizer)
181
+ if CustomTokens.NO_SPONSOR.value in sponsorship_text:
182
+ return []
183
+
184
+ return re.findall(SPONSOR_MATCH_RE, sponsorship_text)
185
+
186
+
187
+ def segments_to_prediction_times(segments, model, tokenizer):
188
+ predicted_time_ranges = []
189
+
190
+ # TODO pass to model simultaneously, not in for loop
191
+ # use 2d array for input ids
192
+ for segment in segments:
193
+ cleaned_batch = [preprocess.clean_text(
194
+ word['text']) for word in segment]
195
+ batch_text = ' '.join(cleaned_batch)
196
+
197
+ matches = predict_sponsor_matches(batch_text, model, tokenizer)
198
+
199
+ for match in matches:
200
+ matched_text = match.split()
201
+ # TODO skip if too short
202
+
203
+ i1, j1, k1 = greedy_match(
204
+ cleaned_batch, matched_text[:MATCH_WINDOW])
205
+ i2, j2, k2 = greedy_match(
206
+ cleaned_batch, matched_text[-MATCH_WINDOW:])
207
+
208
+ extracted_words = segment[i1:i2+k2]
209
+
210
+ if not extracted_words:
211
+ continue
212
+
213
+ predicted_time_ranges.append({
214
+ 'start': word_start(extracted_words[0]),
215
+ 'end': word_end(extracted_words[-1])
216
+ })
217
+
218
+ # Necessary to sort matches by start time
219
+ predicted_time_ranges.sort(key=word_start)
220
+
221
+ # Merge overlapping predictions and sponsorships that are close together
222
+ # Caused by model having max input size
223
+ last_end_time = -1
224
+ final_predicted_time_ranges = []
225
+ for range in predicted_time_ranges:
226
+ start_time = range['start']
227
+ end_time = range['end']
228
+
229
+ if (start_time <= last_end_time <= end_time) or (last_end_time != -1 and start_time - last_end_time <= MERGE_TIME_WITHIN):
230
+ # Ending time of last segment is in this segment, so we extend last prediction range
231
+ final_predicted_time_ranges[-1]['end'] = end_time
232
+
233
+ else: # No overlap, is a new prediction
234
+ final_predicted_time_ranges.append({
235
+ 'start': start_time,
236
+ 'end': end_time,
237
+ })
238
+
239
+ last_end_time = end_time
240
+
241
+ return final_predicted_time_ranges
242
+
243
+
244
+ def main():
245
+ # Test on unseen data
246
+ logging.getLogger().setLevel(logging.DEBUG)
247
+
248
+ hf_parser = HfArgumentParser((
249
+ PredictArguments,
250
+ SegmentationArguments,
251
+ ClassifierArguments
252
+ ))
253
+ predict_args, segmentation_args, classifier_args = hf_parser.parse_args_into_dataclasses()
254
+
255
+ model = AutoModelForSeq2SeqLM.from_pretrained(predict_args.model_path)
256
+ model.to(device())
257
+
258
+ tokenizer = AutoTokenizer.from_pretrained(predict_args.model_path)
259
+
260
+ predict_args.video_id = predict_args.video_id.strip()
261
+ print(
262
+ f'Predicting for https://www.youtube.com/watch?v={predict_args.video_id}')
263
+ predictions = predict(predict_args.video_id, model, tokenizer,
264
+ segmentation_args, classifier_args=classifier_args)
265
+
266
+ for prediction in predictions:
267
+ print(' '.join([w['text'] for w in prediction['words']]))
268
+ print(seconds_to_time(prediction['start']),
269
+ '-->', seconds_to_time(prediction['end']))
270
+ print(prediction['start'], '-->', prediction['end'])
271
+ print(prediction['probability'])
272
+ print()
273
+
274
+ print()
275
+
276
+
277
+ if __name__ == '__main__':
278
+ main()
src/preprocess.py ADDED
@@ -0,0 +1,786 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from typing import Optional
3
+ from datasets import load_dataset
4
+ from model import ModelArguments
5
+ import segment
6
+ from tqdm import tqdm
7
+ from dataclasses import dataclass, field
8
+ from transformers import HfArgumentParser
9
+ from shared import GeneralArguments, CustomTokens
10
+ import csv
11
+ import re
12
+ import random
13
+ import logging
14
+ from youtube_transcript_api import YouTubeTranscriptApi
15
+ from youtube_transcript_api._errors import CouldNotRetrieveTranscript, YouTubeRequestFailed
16
+ import os
17
+ import json
18
+ import time
19
+ import requests
20
+ from utils import InterruptibleThreadPool, Job
21
+
22
+
23
+ def find(s, ch):
24
+ return [i for i, ltr in enumerate(s) if ltr == ch]
25
+
26
+
27
+ def wordify(transcript):
28
+ """Try to replicate format for automatically generated transcripts"""
29
+ words = []
30
+
31
+ for line_index, line in enumerate(transcript):
32
+ text = line['text'].replace('\n', ' ').strip()
33
+ if not text:
34
+ continue
35
+
36
+ start = line['start']
37
+ next_start = transcript[line_index +
38
+ 1]['start'] if line_index < len(transcript) - 1 else float('inf')
39
+ end = min(start + line['duration'], next_start)
40
+ duration = end - start
41
+
42
+ indices = find(text, ' ') + [len(text)]
43
+ start_index = 0
44
+ for i in range(len(indices)):
45
+ word = text[start_index:indices[i]].strip()
46
+ if not word:
47
+ continue # Skip empty words (e.g., \n)
48
+ percentage = start_index/indices[-1]
49
+
50
+ w_duration = len(word)/indices[-1] * duration
51
+
52
+ w_start = start + percentage * duration
53
+
54
+ words.append({
55
+ 'start': round(w_start, 5),
56
+ 'duration': round(w_duration, 5),
57
+ 'end': round(w_start + w_duration, 5),
58
+ 'text': word,
59
+ })
60
+
61
+ start_index = indices[i] + 1
62
+
63
+ return words
64
+
65
+
66
+ def get_manual_words(transcript_list):
67
+ transcript = transcript_list.find_manually_created_transcript(
68
+ ['en-GB', 'en-US', 'en']).fetch()
69
+ return wordify(transcript)
70
+
71
+
72
+ def get_auto_words(transcript_list):
73
+ words = []
74
+ transcript = transcript_list.find_generated_transcript(['en'])
75
+ url = transcript._url + '&fmt=json3'
76
+ info = transcript._http_client.get(url)
77
+
78
+ for event in info.json()['events']:
79
+ start_ms = event.get('tStartMs', 0)
80
+
81
+ for word in event.get('segs') or []:
82
+ offset_ms = word.get('tOffsetMs', 0)
83
+
84
+ texts = word['utf8'].replace(
85
+ CustomTokens.PROFANITY_RAW.value, CustomTokens.PROFANITY_CONVERTED.value
86
+ ).strip().split()
87
+
88
+ for text in texts:
89
+ words.append({
90
+ 'start': (start_ms + offset_ms)/1000,
91
+ 'text': text
92
+ })
93
+
94
+ return words
95
+
96
+
97
+ def get_words(video_id, process=True, fallback=False, transcript_type='auto'):
98
+ """Get parsed video transcript with caching system
99
+ returns None if not processed yet and process is False
100
+ """
101
+ get_manual_if_fail = fallback and transcript_type == 'auto'
102
+ transcript_path = os.path.join(
103
+ 'transcripts', transcript_type, f'{video_id}.json')
104
+ words = []
105
+ try:
106
+ if os.path.exists(transcript_path):
107
+ with open(transcript_path) as fp:
108
+ wds = json.load(fp)
109
+
110
+ if not wds and get_manual_if_fail:
111
+ return get_words(video_id, process, fallback, 'manual')
112
+ return wds
113
+
114
+ elif not process:
115
+ return None
116
+
117
+ transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
118
+
119
+ if transcript_type == 'manual':
120
+ words = get_manual_words(transcript_list)
121
+ else:
122
+ words = get_auto_words(transcript_list)
123
+
124
+ except YouTubeRequestFailed as e:
125
+ print(e)
126
+ time.sleep(30) # Timeout
127
+ return get_words(video_id, process, fallback, transcript_type)
128
+
129
+ except CouldNotRetrieveTranscript:
130
+ if get_manual_if_fail:
131
+ print('fallback')
132
+ return get_words(video_id, process, fallback, 'manual')
133
+
134
+ except json.decoder.JSONDecodeError:
135
+ # Warning, unable to parse JSON
136
+ pass
137
+
138
+ with open(transcript_path, 'w') as fp:
139
+ json.dump(words, fp)
140
+
141
+ return words
142
+
143
+
144
+ # TODO make min_sponsor_segment_length param
145
+ def extract_sponsors(words, min_sponsor_segment_length=5):
146
+ if len(words) < min_sponsor_segment_length:
147
+ return [] # Force short phrases to not be sponsors
148
+
149
+ paragraphs = []
150
+ current = []
151
+ for word in words:
152
+ if not word.get('sponsor') and not current:
153
+ continue
154
+
155
+ if word['sponsor']:
156
+ current.append(word['text'])
157
+ else:
158
+ paragraphs.append(current)
159
+ current = []
160
+ if current:
161
+ paragraphs.append(current)
162
+
163
+ # Remove all too short:
164
+ paragraphs = list(filter(lambda x: len(
165
+ x) >= min_sponsor_segment_length, paragraphs))
166
+
167
+ return paragraphs
168
+
169
+
170
+ def clean_text(text):
171
+
172
+ # Replace impossibly long words with a special token
173
+ # Usually the result of incorrect labelling
174
+ text = re.sub(r'\w{64,}', CustomTokens.LONG_WORD.value, text)
175
+
176
+ SHORT_HYPHENATED_REGEX = r'\w{1,2}(?:-\w{1,2}){3,}(?:-?\w*)'
177
+
178
+ # Replace hyphenated URLs with special token
179
+ # For some reason, youtube sometimes transcribes urls in this form:
180
+ # 'b-a-b-b-e-l-dot-com', 'g-e-t-r-o-m-a-n-com'
181
+ # not 'e-commerce'
182
+ text = re.sub(f'{SHORT_HYPHENATED_REGEX}(?:com|org|net)',
183
+ CustomTokens.HYPHENATED_URL.value, text)
184
+
185
+ # Replace short+hyphenated text with a special token. Of the form:
186
+ # 'i-i-i-i-i-i-i-i-i-i-i-i', 'b-u-m-f-u-z-z-l-e', 'v-e-r-i-t-a-s-i-u-m', 'do-do-do-do-do'
187
+ text = re.sub(SHORT_HYPHENATED_REGEX,
188
+ CustomTokens.SHORT_HYPHENATED.value, text)
189
+
190
+ # Replace URLs with URL_TOKEN
191
+ URL_REGEX = r'(?:(?:http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.(?:[a-zA-Z]){2,6}(?:[a-zA-Z0-9\.\&\/\?\:@\-_=#%])*'
192
+ text = re.sub(URL_REGEX, CustomTokens.URL.value, text)
193
+
194
+ NUM_REGEX = r'(?:\d+,)*(?:\d*[.])?\d+'
195
+
196
+ # Encode specific numeric words
197
+ # Of the form: 12%, 12.34%
198
+ # Usually included in sponsorships
199
+ text = re.sub(f'{NUM_REGEX}%',
200
+ CustomTokens.NUMBER_PERCENTAGE.value, text)
201
+
202
+ # Normal numbers, should not have an effect on sponsorship
203
+ text = re.sub(NUM_REGEX, CustomTokens.NUMBER.value, text)
204
+
205
+ # Replace profanity with special token
206
+ text = text.replace(CustomTokens.PROFANITY_RAW.value,
207
+ CustomTokens.PROFANITY.value)
208
+ text = text.replace(CustomTokens.PROFANITY_CONVERTED.value,
209
+ CustomTokens.PROFANITY.value)
210
+
211
+ return text.strip()
212
+
213
+
214
+ def remove_duplicate_sponsor_segments(sponsor_segments):
215
+ """Choose the best sponsor segment if overlapping with others"""
216
+
217
+ # Algorithm based on SponsorBlock algorithm
218
+ # Find sponsors that are overlapping
219
+ similar = []
220
+ for i in sponsor_segments:
221
+ for j in sponsor_segments:
222
+ # Since we do pairwise, we only check one direction
223
+ if (j['start'] >= i['start'] and j['start'] <= i['end']):
224
+ similar.append([i, j])
225
+
226
+ # Within each group, choose the segment with the most votes.
227
+ processed = []
228
+ best = []
229
+ for i in similar:
230
+ if i in processed:
231
+ continue
232
+ group = i
233
+ for j in similar:
234
+ if j[0] in group or j[1] in group: # If either in, append both
235
+ group.append(j[0])
236
+ group.append(j[1])
237
+ processed.append(j)
238
+
239
+ best.append(max(group, key=lambda item: (
240
+ item['votes'], item['reputation'], item['views'])))
241
+
242
+ return best
243
+
244
+
245
+ @dataclass
246
+ class PreprocessArguments:
247
+ """
248
+ Arguments pertaining to what data we are going to preprocess.
249
+ """
250
+ update_database: bool = field(
251
+ default=False, metadata={'help': 'Download the raw database.'}
252
+ )
253
+
254
+ do_create: bool = field(
255
+ default=False, metadata={'help': 'Merge sponsor segments into single file'}
256
+ )
257
+ min_votes: int = field(
258
+ default=0, metadata={'help': 'Minimum number of votes'})
259
+ # Downvotes will make this negative.
260
+ # 1 = At least one positive vote
261
+
262
+ do_transcribe: bool = field(
263
+ default=False, metadata={'help': 'Get transcripts for videos'}
264
+ )
265
+ num_jobs: int = field(
266
+ default=4, metadata={'help': 'Number of transcripts to download in parallel'})
267
+
268
+ overwrite: bool = field(
269
+ default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
270
+ )
271
+
272
+ do_generate: bool = field(
273
+ default=False, metadata={'help': 'Generate labelled data.'}
274
+ )
275
+
276
+ do_split: bool = field(
277
+ default=False, metadata={'help': 'Generate training, testing and validation data.'}
278
+ )
279
+ percentage_positive: float = field(
280
+ default=0.5, metadata={'help': 'Ratio of positive (sponsor) segments to include in final output'})
281
+
282
+ train_split: float = field(
283
+ default=0.9, metadata={'help': 'Ratio of training data. Value between 0 and 1.'})
284
+
285
+ # TODO play around with ratios? lower test/validation split?
286
+ test_split: float = field(
287
+ default=0.05, metadata={'help': 'Ratio of testing data. Value between 0 and 1.'})
288
+ valid_split: float = field(
289
+ default=0.05, metadata={'help': 'Ratio of validation data. Value between 0 and 1.'})
290
+
291
+ skip_videos: int = field(default=None, metadata={
292
+ 'help': 'Number of videos to skip. Set this to the latest video index to append to the current file'})
293
+
294
+ max_videos: int = field(default=None, metadata={
295
+ 'help': 'Maximum number of videos to preprocess.'})
296
+
297
+ max_segments: int = field(default=None, metadata={
298
+ 'help': 'Maximum number of segments to produce to preprocess.'})
299
+
300
+ raw_data_dir: Optional[str] = field(
301
+ default='raw',
302
+ metadata={
303
+ 'help': 'Raw data directory'
304
+ },
305
+ )
306
+ raw_data_file: Optional[str] = field(
307
+ default='sponsorTimes.csv',
308
+ metadata={
309
+ 'help': 'Raw data file'
310
+ },
311
+ )
312
+
313
+ min_wps: float = field(
314
+ default=0.4, metadata={'help': 'Ignore videos with not enough words spoken per second. This is usually indicitive of video whose captions aren\'t English.'})
315
+ # 0.1 ~ 1%
316
+ # 0.4 ~ 2.5%
317
+ # 0.9 ~ 5%
318
+
319
+
320
+ # Mirrors for database
321
+ MIRRORS = [
322
+ 'https://sponsor.ajay.app/database/sponsorTimes.csv', # Latest
323
+ 'https://sb-mirror.mchang.xyz/sponsorTimes.csv', # 5 minute delay
324
+ 'https://sb.ltn.fi/database/sponsorTimes.csv', # 5 minute delay
325
+ ]
326
+ # TODO only download latest (updates/changes)
327
+
328
+
329
+ def download_file(url, filename):
330
+ """
331
+ Helper method handling downloading large files from `url` to `filename`.
332
+
333
+ Adapted from https://stackoverflow.com/a/42071418
334
+ """
335
+ chunk_size = 1024
336
+ r = requests.get(url, stream=True)
337
+ total_bytes = int(r.headers['Content-Length'])
338
+ with open(filename, 'wb') as f, tqdm(unit='B', total=total_bytes) as progress:
339
+ for chunk in r.iter_content(chunk_size=chunk_size):
340
+ if chunk: # filter out keep-alive new chunks
341
+ progress.update(len(chunk))
342
+ f.write(chunk)
343
+
344
+ return total_bytes == os.path.getsize(filename)
345
+
346
+
347
+ @dataclass
348
+ class ProcessedArguments:
349
+ processed_dir: Optional[str] = field(
350
+ default='processed',
351
+ metadata={
352
+ 'help': 'Processed data directory'
353
+ },
354
+ )
355
+ processed_file: Optional[str] = field(
356
+ default='final.json',
357
+ metadata={
358
+ 'help': 'Processed data file'
359
+ },
360
+ )
361
+
362
+
363
+ def load_datasets(dataset_args):
364
+ print('Reading datasets')
365
+ data_files = {}
366
+
367
+ if dataset_args.train_file is not None:
368
+ data_files['train'] = os.path.join(
369
+ dataset_args.data_dir, dataset_args.train_file)
370
+ if dataset_args.validation_file is not None:
371
+ data_files['validation'] = os.path.join(
372
+ dataset_args.data_dir, dataset_args.validation_file)
373
+ if dataset_args.test_file is not None:
374
+ data_files['test'] = os.path.join(
375
+ dataset_args.data_dir, dataset_args.test_file)
376
+
377
+ return load_dataset('json', data_files=data_files)
378
+
379
+
380
+ @dataclass
381
+ class DatasetArguments:
382
+ data_dir: Optional[str] = field(
383
+ default='data',
384
+ metadata={
385
+ 'help': 'The directory which stores train, test and/or validation data.'
386
+ },
387
+ )
388
+
389
+ train_file: Optional[str] = field(
390
+ default='train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
391
+ )
392
+ validation_file: Optional[str] = field(
393
+ default='valid.json',
394
+ metadata={
395
+ 'help': 'An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines file).'
396
+ },
397
+ )
398
+ test_file: Optional[str] = field(
399
+ default='test.json',
400
+ metadata={
401
+ 'help': 'An optional input test data file to evaluate the metrics (rouge) on (a jsonlines file).'
402
+ },
403
+ )
404
+ excess_file: Optional[str] = field(
405
+ default='excess.json',
406
+ metadata={
407
+ 'help': 'The excess segments left after the split'
408
+ },
409
+ )
410
+
411
+ overwrite_cache: bool = field(
412
+ default=False, metadata={'help': 'Overwrite the cached training and evaluation sets'}
413
+ )
414
+
415
+ positive_file: Optional[str] = field(
416
+ default='sponsor_segments.json', metadata={'help': 'File to output sponsored segments to (a jsonlines file).'}
417
+ )
418
+ negative_file: Optional[str] = field(
419
+ default='normal_segments.json', metadata={'help': 'File to output normal segments to (a jsonlines file).'}
420
+ )
421
+
422
+ def __post_init__(self):
423
+ # TODO check if train/validation datasets exist
424
+ if self.train_file is None and self.validation_file is None:
425
+ raise ValueError(
426
+ 'Need either a dataset name or a training/validation file.')
427
+
428
+
429
+ def main():
430
+ # Responsible for getting transcrips using youtube_transcript_api,
431
+ # then labelling it according to SponsorBlock's API
432
+
433
+ logging.getLogger().setLevel(logging.INFO) # TODO make param
434
+
435
+ # Generate final.json from sponsorTimes.csv
436
+ hf_parser = HfArgumentParser((
437
+ PreprocessArguments,
438
+ ProcessedArguments,
439
+ DatasetArguments,
440
+ segment.SegmentationArguments,
441
+ ModelArguments,
442
+ GeneralArguments
443
+ ))
444
+ preprocess_args, processed_args, dataset_args, segmentation_args, model_args, _ = hf_parser.parse_args_into_dataclasses()
445
+
446
+ raw_dataset_path = os.path.join(
447
+ preprocess_args.raw_data_dir, preprocess_args.raw_data_file)
448
+
449
+ def get_rows():
450
+ with open(raw_dataset_path, newline='') as csvfile:
451
+ reader = csv.DictReader(csvfile)
452
+ for line in reader:
453
+ if line['service'] != 'YouTube':
454
+ continue
455
+
456
+ # TODO add support for other categories and action types?
457
+ if line['category'] != 'sponsor':
458
+ continue
459
+ if line['actionType'] != 'skip':
460
+ continue
461
+
462
+ # Ignore hidden items
463
+ if line['hidden'] == '1' or line['shadowHidden'] == '1':
464
+ continue
465
+
466
+ if len(line['videoID']) != 11:
467
+ continue # Invalid youtube video ID
468
+
469
+ # Skip those that aren't highly voted
470
+ line['votes'] = int(line['votes'])
471
+ # incorrect_votes = int(line['incorrectVotes'])
472
+
473
+ if line['votes'] < preprocess_args.min_votes:
474
+ continue
475
+
476
+ yield line
477
+
478
+ if preprocess_args.update_database:
479
+ print('Updating database')
480
+ for mirror in MIRRORS:
481
+ print('Downloading from', mirror)
482
+ if download_file(mirror, raw_dataset_path):
483
+ break
484
+ print('Failed, trying next')
485
+
486
+ # 'videoID', 'startTime', 'endTime', 'votes', 'locked', 'incorrectVotes', 'UUID',
487
+ # 'userID', 'timeSubmitted', 'views', 'category', 'actionType', 'service', 'videoDuration',
488
+ # 'hidden', 'reputation', 'shadowHidden', 'hashedVideoID', 'userAgent', 'description'
489
+ data_rows = None
490
+ if preprocess_args.do_transcribe:
491
+ print('Collecting videos')
492
+ video_ids = set()
493
+ data_rows = get_rows()
494
+ for row in data_rows:
495
+ video_ids.add(row['videoID'])
496
+
497
+ print('Start transcribing')
498
+ with tqdm(total=len(video_ids)) as progress:
499
+ def on_job_complete(job):
500
+ progress.set_description(f'Processed {job.video_id}')
501
+ progress.update()
502
+
503
+ pool = InterruptibleThreadPool(
504
+ preprocess_args.num_jobs, on_job_complete=on_job_complete)
505
+
506
+ print('Adding jobs to pool')
507
+ for video_id in video_ids:
508
+ job = Job(get_words, video_id)
509
+ job.video_id = video_id
510
+ pool.add_job(job)
511
+
512
+ print('Start processing')
513
+ pool.run()
514
+
515
+ print('Finished transcribing')
516
+
517
+ final_path = os.path.join(
518
+ processed_args.processed_dir, processed_args.processed_file)
519
+
520
+ if os.path.exists(final_path) and not preprocess_args.do_create:
521
+ logging.info(f'{final_path} exists, opening file')
522
+ with open(final_path) as fp:
523
+ final_data = json.load(fp)
524
+ else:
525
+ print('Create final data')
526
+
527
+ final_data = {}
528
+
529
+ if data_rows is None:
530
+ data_rows = get_rows()
531
+
532
+ # TODO add progress bar
533
+ # TODO parallelise?
534
+ for line in data_rows:
535
+ video_id = line['videoID']
536
+
537
+ if video_id not in final_data:
538
+ final_data[video_id] = []
539
+
540
+ segment_start = float(line['startTime'])
541
+ segment_end = float(line['endTime'])
542
+
543
+ video_words = get_words(video_id, process=True)
544
+ segment_words = segment.extract_segment(
545
+ video_words, segment_start, segment_end)
546
+
547
+ if len(segment_words) <= 1:
548
+ continue # Useless to add segment since no words
549
+
550
+ # duration = segment.word_end(segment_words[-1]) - segment.word_start(segment_words[0])
551
+ duration = segment_end - segment_start
552
+ wps = len(segment_words)/duration if duration > 0 else 0
553
+
554
+ if wps < preprocess_args.min_wps:
555
+ print('bad segment in', video_id, '| wps =', wps)
556
+ continue
557
+
558
+ final_data[video_id].append({
559
+ 'start': segment_start,
560
+ 'end': segment_end,
561
+ 'votes': line['votes'],
562
+ 'locked': line['locked'] == '1',
563
+ 'views': line['views'],
564
+ 'reputation': line['reputation'],
565
+ 'category': line['category'],
566
+ 'action': line['actionType'],
567
+ 'uuid': line['UUID'],
568
+ })
569
+
570
+ # Remove duplicate sponsor segments by choosing best (most votes)
571
+ for key in final_data:
572
+ final_data[key] = remove_duplicate_sponsor_segments(
573
+ final_data[key])
574
+
575
+ # Save data
576
+ with open(final_path, 'w') as fp:
577
+ json.dump(final_data, fp)
578
+
579
+ # final_data = preprocess(
580
+ # raw_dataset_path, final_path, preprocess_args.min_votes)
581
+ # # TODO save metadata in final.json?
582
+
583
+ logging.info(f'Found {len(final_data)} videos')
584
+
585
+ # TODO shuffle final_data
586
+
587
+ # if not os.path.exists(excess_path) or preprocess_args.overwrite
588
+ # TODO use overwrite param
589
+
590
+ os.makedirs(dataset_args.data_dir, exist_ok=True)
591
+
592
+ positive_file = os.path.join(
593
+ dataset_args.data_dir, dataset_args.positive_file)
594
+ negative_file = os.path.join(
595
+ dataset_args.data_dir, dataset_args.negative_file)
596
+
597
+ if preprocess_args.do_generate:
598
+ print('Generating')
599
+ from model import get_tokenizer
600
+
601
+ # max_videos=preprocess_args.max_videos,
602
+ # max_segments=preprocess_args.max_segments,
603
+ # , max_videos, max_segments
604
+
605
+ tokenizer = get_tokenizer(model_args)
606
+
607
+ count_videos = 0
608
+ count_segments = 0 # TODO
609
+
610
+ write_mode = 'w' if preprocess_args.overwrite else 'a'
611
+
612
+ get_all = preprocess_args.max_videos is None
613
+ if get_all:
614
+ total = len(final_data)
615
+ else:
616
+ total = preprocess_args.max_videos
617
+
618
+ index = 0
619
+ data = final_data.items()
620
+ if preprocess_args.skip_videos is not None:
621
+ print('Skipping first', preprocess_args.skip_videos, 'videos')
622
+ data = itertools.islice(data, preprocess_args.skip_videos, None)
623
+ index = preprocess_args.skip_videos
624
+
625
+ if get_all:
626
+ total = max(0, total - preprocess_args.skip_videos)
627
+ else:
628
+ total = min(len(final_data) -
629
+ preprocess_args.skip_videos, total)
630
+
631
+ with open(positive_file, write_mode, encoding='utf-8') as positive, \
632
+ open(negative_file, write_mode, encoding='utf-8') as negative, \
633
+ tqdm(total=total) as progress:
634
+
635
+ for video_id, sponsor_segments in data:
636
+ index += 1 # TODO FIX index + incrementing
637
+ progress.set_description(f'Processing {video_id}')
638
+
639
+ if get_all:
640
+ progress.update()
641
+ elif count_videos >= preprocess_args.max_videos:
642
+ break
643
+
644
+ words = get_words(video_id, False)
645
+ if not words:
646
+ continue
647
+
648
+ num_words = len(words)
649
+ if num_words <= 1:
650
+ continue
651
+
652
+ # TODO only count words that aren't [Music], [Applause], etc.
653
+
654
+ segments = segment.generate_labelled_segments(
655
+ words, tokenizer, segmentation_args, sponsor_segments)
656
+
657
+ if not segments:
658
+ continue
659
+
660
+ count_videos += 1
661
+ if not get_all:
662
+ progress.update()
663
+
664
+ for seg in segments:
665
+
666
+ segment_text = ' '.join((x['text'] for x in seg))
667
+
668
+ extracted_text = ''
669
+ for p in extract_sponsors(seg):
670
+ p_text = ' '.join(p)
671
+ extracted_text += f'{CustomTokens.START_SPONSOR.value} {p_text} {CustomTokens.END_SPONSOR.value}. '
672
+
673
+ duration = segment.word_end(
674
+ seg[-1]) - segment.word_start(seg[0])
675
+ wps = len(seg)/duration if duration > 0 else 0
676
+ # Ignore segments with "not enough words" in the transcript
677
+ if wps < preprocess_args.min_wps:
678
+ continue
679
+
680
+ d = {
681
+ 'video_index': index,
682
+ 'video_id': video_id,
683
+ 'text': clean_text(segment_text),
684
+ 'words_per_second': wps,
685
+ }
686
+
687
+ d['sponsor'] = bool(extracted_text)
688
+ d['extracted'] = clean_text(
689
+ extracted_text) if d['sponsor'] else CustomTokens.NO_SPONSOR.value
690
+
691
+ print(json.dumps(d), file=(
692
+ positive if d['sponsor'] else negative))
693
+
694
+ if preprocess_args.do_split:
695
+ print('Splitting')
696
+ print('Read files')
697
+
698
+ with open(positive_file, encoding='utf-8') as positive:
699
+ sponsors = positive.readlines()
700
+
701
+ with open(negative_file, encoding='utf-8') as negative:
702
+ non_sponsors = negative.readlines()
703
+
704
+ print('Shuffle')
705
+ random.shuffle(sponsors)
706
+ random.shuffle(non_sponsors)
707
+
708
+ print('Calculate ratios')
709
+ # Ensure correct ratio of positive to negative segments
710
+ percentage_negative = 1 - preprocess_args.percentage_positive
711
+
712
+ if preprocess_args.percentage_positive * len(sponsors) > len(non_sponsors):
713
+ # Negative is limiting
714
+ z = int(preprocess_args.percentage_positive /
715
+ percentage_negative * len(non_sponsors))
716
+
717
+ excess = sponsors[z:]
718
+ sponsors = sponsors[:z]
719
+
720
+ else:
721
+ # Positive is limiting
722
+ z = int(percentage_negative /
723
+ preprocess_args.percentage_positive * len(sponsors))
724
+
725
+ excess = non_sponsors[z:]
726
+ non_sponsors = non_sponsors[:z]
727
+
728
+ print('Join')
729
+ all_labelled_segments = sponsors + non_sponsors
730
+
731
+ random.shuffle(all_labelled_segments)
732
+
733
+ print('Split')
734
+ ratios = [preprocess_args.train_split,
735
+ preprocess_args.test_split,
736
+ preprocess_args.valid_split]
737
+
738
+ train_data, test_data, valid_data = split(
739
+ all_labelled_segments, ratios)
740
+
741
+ splits = {
742
+ dataset_args.train_file: train_data,
743
+ dataset_args.test_file: test_data,
744
+ dataset_args.validation_file: valid_data
745
+ }
746
+
747
+ # Output training, testing and validation data
748
+ for name, items in splits.items():
749
+ outfile = os.path.join(dataset_args.data_dir, name)
750
+ if not os.path.exists(outfile) or preprocess_args.overwrite:
751
+ with open(outfile, 'w', encoding='utf-8') as fp:
752
+ fp.writelines(items)
753
+ else:
754
+ print('Skipping', name)
755
+
756
+ print('Write')
757
+ # Save excess items
758
+ excess_path = os.path.join(
759
+ dataset_args.data_dir, dataset_args.excess_file)
760
+ if not os.path.exists(excess_path) or preprocess_args.overwrite:
761
+ with open(excess_path, 'w', encoding='utf-8') as fp:
762
+ fp.writelines(excess)
763
+ else:
764
+ print('Skipping', dataset_args.excess_file)
765
+
766
+ print('Finished splitting:', len(sponsors),
767
+ 'sponsors,', len(non_sponsors), 'non sponsors')
768
+
769
+
770
+ def split(arr, ratios):
771
+ """Split array according to ratios. Sum of ratios should be less than 1"""
772
+
773
+ to_return = []
774
+
775
+ cumulative_sum = 0
776
+ for r in ratios:
777
+ current = cumulative_sum
778
+
779
+ cumulative_sum += r * len(arr)
780
+ to_return.append(arr[int(current):int(cumulative_sum)])
781
+
782
+ return to_return
783
+
784
+
785
+ if __name__ == '__main__':
786
+ main()
src/segment.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import preprocess
2
+ from shared import CustomTokens
3
+ from dataclasses import dataclass, field
4
+
5
+
6
+ @dataclass
7
+ class SegmentationArguments:
8
+ pause_threshold: int = field(default=2, metadata={
9
+ 'help': 'When the time between words is greater than pause threshold, force into a new segment'})
10
+
11
+
12
+ # WORDS TO ALWAYS HAVE ON THEIR OWN
13
+ # always_split_re = re.compile(r'\[\w+\]')
14
+ # e.g., [Laughter], [Applause], [Music]
15
+ always_split = [
16
+ CustomTokens.MUSIC.value,
17
+ CustomTokens.APPLAUSE.value,
18
+ CustomTokens.LAUGHTER.value
19
+ ]
20
+
21
+
22
+ def get_overlapping_chunks_of_tokens(tokens, size, overlap):
23
+ for i in range(0, len(tokens), size-overlap+1):
24
+ yield tokens[i:i+size]
25
+
26
+
27
+ # Generate up to max_tokens - SAFETY_TOKENS
28
+ SAFETY_TOKENS = 8
29
+
30
+
31
+ # TODO play around with this?
32
+ OVERLAP_TOKEN_PERCENTAGE = 0.5 # 0.25
33
+
34
+
35
+ def add_labels_to_words(words, sponsor_segments):
36
+
37
+ # TODO binary search
38
+ for word in words:
39
+ word['sponsor'] = False
40
+ for sponsor_segment in sponsor_segments:
41
+ if sponsor_segment['start'] <= word['start'] <= sponsor_segment['end']:
42
+ word['sponsor'] = True
43
+
44
+ # TODO use extract_segment with mapping function?
45
+ # TODO remove sponsor segments that contain mostly empty space?
46
+
47
+ return words
48
+
49
+
50
+ def generate_labelled_segments(words, tokenizer, segmentation_args, sponsor_segments):
51
+ segments = generate_segments(words, tokenizer, segmentation_args)
52
+
53
+ labelled_segments = list(
54
+ map(lambda x: add_labels_to_words(x, sponsor_segments), segments))
55
+
56
+ return labelled_segments
57
+
58
+
59
+ def word_start(word):
60
+ return word['start']
61
+
62
+
63
+ def word_end(word):
64
+ return word.get('end', word['start'])
65
+
66
+
67
+ def generate_segments(words, tokenizer, segmentation_args):
68
+ first_pass_segments = []
69
+
70
+ for index, word in enumerate(words):
71
+ # Get length of tokenized word
72
+ cleaned = preprocess.clean_text(word['text'])
73
+ word['num_tokens'] = len(
74
+ tokenizer(cleaned, add_special_tokens=False, truncation=True).input_ids)
75
+
76
+ add_new_segment = index == 0
77
+ if not add_new_segment:
78
+
79
+ if word['text'] in always_split or words[index-1]['text'] in always_split:
80
+ add_new_segment = True
81
+
82
+ # Pause too small, do not split
83
+ elif word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
84
+ add_new_segment = True
85
+
86
+ if add_new_segment: # New segment
87
+ first_pass_segments.append([word])
88
+
89
+ else: # Add to current segment
90
+ first_pass_segments[-1].append(word)
91
+
92
+ max_q_size = tokenizer.model_max_length - SAFETY_TOKENS
93
+
94
+ buffer_size = OVERLAP_TOKEN_PERCENTAGE*max_q_size # tokenizer.model_max_length
95
+
96
+ # In second pass, we split those segments if too big
97
+ second_pass_segments = []
98
+ for segment in first_pass_segments:
99
+ current_segment_num_tokens = 0
100
+ current_segment = []
101
+ for word in segment:
102
+ if current_segment_num_tokens + word['num_tokens'] < max_q_size:
103
+ # Can add tokens to current segment
104
+ current_segment.append(word)
105
+ current_segment_num_tokens += word['num_tokens']
106
+ else:
107
+ # Adding this token would make it have too many tokens
108
+ # We save this batch and create new
109
+ second_pass_segments.append(current_segment.copy())
110
+
111
+ current_segment.append(word)
112
+ current_segment_num_tokens += word['num_tokens']
113
+
114
+ while current_segment_num_tokens > buffer_size and current_segment:
115
+ first_word = current_segment.pop(0)
116
+ current_segment_num_tokens -= first_word['num_tokens']
117
+
118
+ if current_segment:
119
+ second_pass_segments.append(current_segment.copy())
120
+
121
+ return second_pass_segments
122
+
123
+
124
+ def extract_segment(words, start, end, map_function=None):
125
+ """Extract a segment of words that are between (inclusive) the start and end points"""
126
+ segment_words = []
127
+
128
+ if start > end:
129
+ return segment_words
130
+
131
+ # TODO change to binary search
132
+ for w in words: # Assumes words are sorted
133
+ if word_end(w) < start:
134
+ continue # Ignore
135
+ if word_start(w) > end:
136
+ break # Done with range
137
+ if map_function is not None and callable(map_function):
138
+ w = map_function(w)
139
+
140
+ segment_words.append(w)
141
+
142
+ return segment_words
src/shared.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ from time import time_ns
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ from typing import Optional
7
+ from dataclasses import dataclass, field
8
+ from enum import Enum
9
+
10
+
11
+ class CustomTokens(Enum):
12
+ URL = 'URL_TOKEN'
13
+ HYPHENATED_URL = 'HYPHENATED_URL_TOKEN'
14
+ NUMBER_PERCENTAGE = 'NUMBER_PERCENTAGE_TOKEN'
15
+ NUMBER = 'NUMBER_TOKEN'
16
+
17
+ START_SPONSOR = 'START_SPONSOR'
18
+ END_SPONSOR = 'END_SPONSOR'
19
+ NO_SPONSOR = 'NO_SPONSOR_FOUND'
20
+
21
+ SHORT_HYPHENATED = 'SHORT_HYPHENATED_TOKEN'
22
+ LONG_WORD = 'LONG_WORD_TOKEN'
23
+
24
+ # Custom YouTube tokens
25
+ MUSIC = '[Music]'
26
+ APPLAUSE = '[Applause]'
27
+ LAUGHTER = '[Laughter]'
28
+
29
+ PROFANITY_RAW = '[ __ ]' # How YouTube transcribes profanity
30
+ PROFANITY_CONVERTED = '*****' # Safer version for tokenizing
31
+ PROFANITY = 'PROFANITY_TOKEN'
32
+
33
+ @classmethod
34
+ def custom_tokens(cls):
35
+ return [e.value for e in cls]
36
+
37
+ @classmethod
38
+ def add_custom_tokens(cls, tokenizer):
39
+ tokenizer.add_tokens(cls.custom_tokens())
40
+
41
+
42
+ @dataclass
43
+ class OutputArguments:
44
+
45
+ output_dir: str = field(
46
+ default='out',
47
+ metadata={
48
+ 'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
49
+ },
50
+ )
51
+ checkpoint: Optional[str] = field(
52
+ default=None,
53
+ metadata={
54
+ 'help': 'Choose the checkpoint/model to train from or test with. Defaults to the latest checkpoint found in `output_dir`.'
55
+ },
56
+ )
57
+ models_dir: str = field(
58
+ default='models',
59
+ metadata={
60
+ 'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
61
+ },
62
+ )
63
+ # classifier_dir: str = field(
64
+ # default='out',
65
+ # metadata={
66
+ # 'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
67
+ # },
68
+ # )
69
+
70
+
71
+ def seed_factory():
72
+ return time_ns() % (2**32 - 1)
73
+
74
+
75
+ @dataclass
76
+ class GeneralArguments:
77
+ seed: Optional[int] = field(default_factory=seed_factory, metadata={
78
+ 'help': 'Set seed for deterministic training and testing. By default, it uses the current time (results in essentially random results).'
79
+ })
80
+
81
+ def __post_init__(self):
82
+ random.seed(self.seed)
83
+ np.random.seed(self.seed)
84
+ torch.manual_seed(self.seed)
85
+ torch.cuda.manual_seed_all(self.seed)
86
+
87
+
88
+ def device():
89
+ return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
90
+
91
+
92
+ def reset():
93
+ torch.clear_autocast_cache()
94
+ torch.cuda.empty_cache()
95
+ gc.collect()
96
+ print(torch.cuda.memory_summary(device=None, abbreviated=False))
src/train.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from preprocess import load_datasets, DatasetArguments
2
+ from predict import ClassifierArguments, SPONSOR_MATCH_RE, DEFAULT_TOKEN_PREFIX
3
+ from shared import device
4
+ from shared import GeneralArguments, OutputArguments
5
+ from model import ModelArguments
6
+ import transformers
7
+ import logging
8
+ from model import get_model, get_tokenizer
9
+ import logging
10
+ import os
11
+ import sys
12
+ from dataclasses import dataclass, field
13
+ from typing import Optional
14
+ import datasets
15
+ import pickle
16
+ from transformers import (
17
+ DataCollatorForSeq2Seq,
18
+ HfArgumentParser,
19
+ Seq2SeqTrainer,
20
+ Seq2SeqTrainingArguments
21
+ )
22
+ from transformers.trainer_utils import get_last_checkpoint
23
+ from transformers.utils import check_min_version
24
+ from transformers.utils.versions import require_version
25
+ from sklearn.linear_model import LogisticRegression
26
+ from sklearn.feature_extraction.text import TfidfVectorizer
27
+
28
+ import re
29
+
30
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
31
+ check_min_version('4.13.0.dev0')
32
+ require_version('datasets>=1.8.0',
33
+ 'To fix: pip install -r requirements.txt')
34
+
35
+ os.environ['WANDB_DISABLED'] = 'true'
36
+
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ # Setup logging
41
+ logging.basicConfig(
42
+ format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
43
+ datefmt='%m/%d/%Y %H:%M:%S',
44
+ handlers=[logging.StreamHandler(sys.stdout)],
45
+ )
46
+
47
+
48
+ @dataclass
49
+ class DataTrainingArguments:
50
+ """
51
+ Arguments pertaining to what data we are going to input our model for training and eval.
52
+ """
53
+
54
+ preprocessing_num_workers: Optional[int] = field(
55
+ default=None,
56
+ metadata={'help': 'The number of processes to use for the preprocessing.'},
57
+ )
58
+ # https://github.com/huggingface/transformers/issues/5204
59
+ max_source_length: Optional[int] = field(
60
+ default=512,
61
+ metadata={
62
+ 'help': 'The maximum total input sequence length after tokenization. Sequences longer '
63
+ 'than this will be truncated, sequences shorter will be padded.'
64
+ },
65
+ )
66
+ max_target_length: Optional[int] = field(
67
+ default=512,
68
+ metadata={
69
+ 'help': 'The maximum total sequence length for target text after tokenization. Sequences longer '
70
+ 'than this will be truncated, sequences shorter will be padded.'
71
+ },
72
+ )
73
+ val_max_target_length: Optional[int] = field(
74
+ default=None,
75
+ metadata={
76
+ 'help': 'The maximum total sequence length for validation target text after tokenization. Sequences longer '
77
+ 'than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`.'
78
+ 'This argument is also used to override the ``max_length`` param of ``model.generate``, which is used '
79
+ 'during ``evaluate`` and ``predict``.'
80
+ },
81
+ )
82
+ pad_to_max_length: bool = field(
83
+ default=False,
84
+ metadata={
85
+ 'help': 'Whether to pad all samples to model maximum sentence length. '
86
+ 'If False, will pad the samples dynamically when batching to the maximum length in the batch. More '
87
+ 'efficient on GPU but very bad for TPU.'
88
+ },
89
+ )
90
+ max_train_samples: Optional[int] = field(
91
+ default=None,
92
+ metadata={
93
+ 'help': 'For debugging purposes or quicker training, truncate the number of training examples to this value if set.'
94
+ },
95
+ )
96
+ max_eval_samples: Optional[int] = field(
97
+ default=None,
98
+ metadata={
99
+ 'help': 'For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set.'
100
+ },
101
+ )
102
+ max_predict_samples: Optional[int] = field(
103
+ default=None,
104
+ metadata={
105
+ 'help': 'For debugging purposes or quicker training, truncate the number of prediction examples to this value if set.'
106
+ },
107
+ )
108
+ num_beams: Optional[int] = field(
109
+ default=None,
110
+ metadata={
111
+ 'help': 'Number of beams to use for evaluation. This argument will be passed to ``model.generate``, '
112
+ 'which is used during ``evaluate`` and ``predict``.'
113
+ },
114
+ )
115
+ ignore_pad_token_for_loss: bool = field(
116
+ default=True,
117
+ metadata={
118
+ 'help': 'Whether to ignore the tokens corresponding to padded labels in the loss computation or not.'
119
+ },
120
+ )
121
+ source_prefix: Optional[str] = field(
122
+ default=DEFAULT_TOKEN_PREFIX, metadata={
123
+ 'help': 'A prefix to add before every source text (useful for T5 models).'}
124
+ )
125
+
126
+ # TODO add vectorizer params
127
+
128
+ def __post_init__(self):
129
+ if self.val_max_target_length is None:
130
+ self.val_max_target_length = self.max_target_length
131
+
132
+
133
+ @dataclass
134
+ class SequenceTrainingArguments(OutputArguments, Seq2SeqTrainingArguments):
135
+ seed: Optional[int] = GeneralArguments.__dataclass_fields__['seed']
136
+
137
+ num_train_epochs: float = field(
138
+ default=1, metadata={'help': 'Total number of training epochs to perform.'})
139
+
140
+ save_steps: int = field(default=2500, metadata={
141
+ 'help': 'Save checkpoint every X updates steps.'})
142
+ eval_steps: int = field(default=2500, metadata={
143
+ 'help': 'Run an evaluation every X steps.'})
144
+ logging_steps: int = field(default=2500, metadata={
145
+ 'help': 'Log every X updates steps.'})
146
+
147
+ skip_train_transformer: bool = field(default=False, metadata={
148
+ 'help': 'Whether to skip training the transformer.'})
149
+ train_classifier: bool = field(default=False, metadata={
150
+ 'help': 'Whether to run training on the 2nd phase (classifier).'})
151
+
152
+ # do_eval: bool = field(default=False, metadata={
153
+ # 'help': 'Whether to run eval on the dev set.'})
154
+ do_predict: bool = field(default=False, metadata={
155
+ 'help': 'Whether to run predictions on the test set.'})
156
+
157
+ per_device_train_batch_size: int = field(
158
+ default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for training.'}
159
+ )
160
+ per_device_eval_batch_size: int = field(
161
+ default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for evaluation.'}
162
+ )
163
+
164
+ # report_to: Optional[List[str]] = field(
165
+ # default=None, metadata={"help": "The list of integrations to report the results and logs to."}
166
+ # )
167
+ evaluation_strategy: str = field(
168
+ default='steps',
169
+ metadata={
170
+ 'help': 'The evaluation strategy to use.',
171
+ 'choices': ['no', 'steps', 'epoch']
172
+ },
173
+ )
174
+
175
+ # evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
176
+ # The evaluation strategy to adopt during training. Possible values are:
177
+
178
+ # * :obj:`"no"`: No evaluation is done during training.
179
+ # * :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
180
+ # * :obj:`"epoch"`: Evaluation is done at the end of each epoch.
181
+
182
+
183
+ def main():
184
+ # reset()
185
+
186
+ # See all possible arguments in src/transformers/training_args.py
187
+ # or by passing the --help flag to this script.
188
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
189
+
190
+ hf_parser = HfArgumentParser((
191
+ ModelArguments,
192
+ DatasetArguments,
193
+ DataTrainingArguments,
194
+ SequenceTrainingArguments,
195
+ ClassifierArguments
196
+ ))
197
+ model_args, dataset_args, data_training_args, training_args, classifier_args = hf_parser.parse_args_into_dataclasses()
198
+
199
+ log_level = training_args.get_process_log_level()
200
+ logger.setLevel(log_level)
201
+ datasets.utils.logging.set_verbosity(log_level)
202
+ transformers.utils.logging.set_verbosity(log_level)
203
+ transformers.utils.logging.enable_default_handler()
204
+ transformers.utils.logging.enable_explicit_format()
205
+
206
+ # Set seed before initializing model.
207
+ # set_seed(training_args.seed)
208
+
209
+ # Log on each process the small summary:
210
+ logger.warning(
211
+ f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}'
212
+ + f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
213
+ )
214
+ logger.info(f'Training/evaluation parameters {training_args}')
215
+
216
+ # FP16 https://github.com/huggingface/transformers/issues/9295
217
+
218
+ # Works:
219
+ # https://huggingface.co/docs/transformers/model_doc/t5v1.1
220
+ # google/t5-v1_1-small
221
+ # google/t5-v1_1-base
222
+ # google/t5-v1_1-large
223
+ # google/t5-v1_1-xl
224
+ # google/t5-v1_1-xxl
225
+
226
+ # https://huggingface.co/docs/transformers/model_doc/t5
227
+ # t5-small
228
+ # t5-base
229
+ # t5-large
230
+ # t5-3b
231
+ # t5-11b
232
+
233
+ # allenai/led-base-16384 - https://github.com/huggingface/transformers/issues/9810
234
+
235
+ # Further work:
236
+ # Multilingual- https://huggingface.co/docs/transformers/model_doc/mt5
237
+
238
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
239
+ # download the dataset.
240
+ if training_args.skip_train_transformer and not training_args.train_classifier:
241
+ print('Nothing to do. Exiting')
242
+ return
243
+
244
+ raw_datasets = load_datasets(dataset_args)
245
+ # , cache_dir=model_args.cache_dir
246
+
247
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
248
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
249
+
250
+ if training_args.train_classifier:
251
+ print('Train classifier')
252
+ # 1. Vectorize raw data to pass into classifier
253
+ # CountVectorizer TfidfVectorizer
254
+ # TfidfVectorizer - better (comb of CountVectorizer)
255
+ vectorizer = TfidfVectorizer( # CountVectorizer
256
+ # lowercase=False,
257
+ # stop_words='english', # TODO optimise stop words?
258
+ # stop_words=stop_words,
259
+
260
+ ngram_range=(1, 2), # best so far
261
+ # max_features=8000 # remove for higher accuracy?
262
+ max_features=50000
263
+ # max_features=10000
264
+ )
265
+
266
+ train_test_data = {
267
+ 'train': {
268
+ 'X': [],
269
+ 'y': []
270
+ },
271
+ 'test': {
272
+ 'X': [],
273
+ 'y': []
274
+ }
275
+ }
276
+
277
+ print('Splitting')
278
+ for ds_type in train_test_data:
279
+ dataset = raw_datasets[ds_type]
280
+
281
+ for row in dataset:
282
+
283
+ # Get matches:
284
+ if row['sponsor']:
285
+ matches = re.findall(SPONSOR_MATCH_RE, row['extracted'])
286
+ else:
287
+ matches = [row['text']]
288
+
289
+ for match in matches:
290
+ train_test_data[ds_type]['X'].append(match)
291
+ train_test_data[ds_type]['y'].append(row['sponsor'])
292
+
293
+ print('Fitting')
294
+ _X_train = vectorizer.fit_transform(train_test_data['train']['X'])
295
+ _X_test = vectorizer.transform(train_test_data['test']['X'])
296
+
297
+ y_train = train_test_data['train']['y']
298
+ y_test = train_test_data['test']['y']
299
+
300
+ # 2. Create classifier
301
+ classifier = LogisticRegression(max_iter=500)
302
+
303
+ # 3. Fit data
304
+ print('fit classifier')
305
+ classifier.fit(_X_train, y_train)
306
+
307
+ # 4. Measure accuracy
308
+ accuracy = classifier.score(_X_test, y_test)
309
+
310
+ print(f'[LogisticRegression] Accuracy percent:',
311
+ round(accuracy*100, 3))
312
+
313
+ # 5. Save classifier and vectorizer
314
+ with open(os.path.join(classifier_args.classifier_dir, classifier_args.classifier_file), 'wb') as fp:
315
+ pickle.dump(classifier, fp)
316
+
317
+ with open(os.path.join(classifier_args.classifier_dir, classifier_args.vectorizer_file), 'wb') as fp:
318
+ pickle.dump(vectorizer, fp)
319
+
320
+ if not training_args.skip_train_transformer:
321
+
322
+ if data_training_args.source_prefix is None and 't5-' in model_args.model_name_or_path:
323
+ logger.warning(
324
+ "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with `--source_prefix 'summarize: ' `"
325
+ )
326
+
327
+ # Detecting last checkpoint.
328
+ last_checkpoint = None
329
+ if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
330
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
331
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
332
+ raise ValueError(
333
+ f'Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome.'
334
+ )
335
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
336
+ logger.info(
337
+ f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
338
+ )
339
+
340
+ # Load pretrained model and tokenizer
341
+ tokenizer = get_tokenizer(model_args)
342
+ model = get_model(model_args)
343
+ model.to(device())
344
+ model.resize_token_embeddings(len(tokenizer))
345
+
346
+ if model.config.decoder_start_token_id is None:
347
+ raise ValueError(
348
+ 'Make sure that `config.decoder_start_token_id` is correctly defined')
349
+
350
+ if hasattr(model.config, 'max_position_embeddings') and model.config.max_position_embeddings < data_training_args.max_source_length:
351
+ if model_args.resize_position_embeddings is None:
352
+ logger.warning(
353
+ f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} to {data_training_args.max_source_length}."
354
+ )
355
+ model.resize_position_embeddings(
356
+ data_training_args.max_source_length)
357
+
358
+ elif model_args.resize_position_embeddings:
359
+ model.resize_position_embeddings(
360
+ data_training_args.max_source_length)
361
+
362
+ else:
363
+ raise ValueError(
364
+ f'`--max_source_length` is set to {data_training_args.max_source_length}, but the model only has {model.config.max_position_embeddings}'
365
+ f' position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically '
366
+ "resize the model's position encodings by passing `--resize_position_embeddings`."
367
+ )
368
+
369
+ # Preprocessing the datasets.
370
+ # We need to tokenize inputs and targets.
371
+ column_names = raw_datasets['train'].column_names
372
+
373
+ # Temporarily set max_target_length for training.
374
+ max_target_length = data_training_args.max_target_length
375
+ padding = 'max_length' if data_training_args.pad_to_max_length else False
376
+
377
+ if training_args.label_smoothing_factor > 0 and not hasattr(model, 'prepare_decoder_input_ids_from_labels'):
378
+ logger.warning(
379
+ 'label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for'
380
+ f'`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory'
381
+ )
382
+
383
+ prefix = data_training_args.source_prefix if data_training_args.source_prefix is not None else ''
384
+
385
+ # https://github.com/huggingface/transformers/issues/5204
386
+ def preprocess_function(examples):
387
+ inputs = examples['text']
388
+ targets = examples['extracted']
389
+ inputs = [prefix + inp for inp in inputs]
390
+ model_inputs = tokenizer(
391
+ inputs, max_length=data_training_args.max_source_length, padding=padding, truncation=True)
392
+
393
+ # Setup the tokenizer for targets
394
+ with tokenizer.as_target_tokenizer():
395
+ labels = tokenizer(
396
+ targets, max_length=max_target_length, padding=padding, truncation=True)
397
+
398
+ # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
399
+ # padding in the loss.
400
+ if padding == 'max_length' and data_training_args.ignore_pad_token_for_loss:
401
+ labels['input_ids'] = [
402
+ [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels['input_ids']
403
+ ]
404
+ model_inputs['labels'] = labels['input_ids']
405
+
406
+ return model_inputs
407
+
408
+ def prepare_dataset(dataset, desc):
409
+ return dataset.map(
410
+ preprocess_function,
411
+ batched=True,
412
+ num_proc=data_training_args.preprocessing_num_workers,
413
+ remove_columns=column_names,
414
+ load_from_cache_file=not dataset_args.overwrite_cache,
415
+ desc=desc, # tokenizing train dataset
416
+ )
417
+ # train_dataset # TODO shuffle?
418
+
419
+ # if training_args.do_train:
420
+ if 'train' not in raw_datasets: # TODO do checks above?
421
+ raise ValueError('Train dataset missing')
422
+ train_dataset = raw_datasets['train']
423
+ if data_training_args.max_train_samples is not None:
424
+ train_dataset = train_dataset.select(
425
+ range(data_training_args.max_train_samples))
426
+ with training_args.main_process_first(desc='train dataset map pre-processing'):
427
+ train_dataset = prepare_dataset(
428
+ train_dataset, desc='Running tokenizer on train dataset')
429
+
430
+ max_target_length = data_training_args.val_max_target_length
431
+ if 'validation' not in raw_datasets:
432
+ raise ValueError('Validation dataset missing')
433
+ eval_dataset = raw_datasets['validation']
434
+ if data_training_args.max_eval_samples is not None:
435
+ eval_dataset = eval_dataset.select(
436
+ range(data_training_args.max_eval_samples))
437
+ with training_args.main_process_first(desc='validation dataset map pre-processing'):
438
+ eval_dataset = prepare_dataset(
439
+ eval_dataset, desc='Running tokenizer on validation dataset')
440
+
441
+ if 'test' not in raw_datasets:
442
+ raise ValueError('Test dataset missing')
443
+ predict_dataset = raw_datasets['test']
444
+ if data_training_args.max_predict_samples is not None:
445
+ predict_dataset = predict_dataset.select(
446
+ range(data_training_args.max_predict_samples))
447
+ with training_args.main_process_first(desc='prediction dataset map pre-processing'):
448
+ predict_dataset = prepare_dataset(
449
+ predict_dataset, desc='Running tokenizer on prediction dataset')
450
+
451
+ # Data collator
452
+ label_pad_token_id = - \
453
+ 100 if data_training_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
454
+ data_collator = DataCollatorForSeq2Seq(
455
+ tokenizer,
456
+ model=model,
457
+ label_pad_token_id=label_pad_token_id,
458
+ pad_to_multiple_of=8 if training_args.fp16 else None,
459
+ )
460
+
461
+ # Done processing datasets
462
+
463
+ # Initialize our Trainer
464
+ trainer = Seq2SeqTrainer(
465
+ model=model,
466
+ args=training_args,
467
+ train_dataset=train_dataset,
468
+ eval_dataset=eval_dataset,
469
+ tokenizer=tokenizer,
470
+ data_collator=data_collator,
471
+ )
472
+
473
+ # Training
474
+ checkpoint = None
475
+ if training_args.resume_from_checkpoint is not None:
476
+ checkpoint = training_args.resume_from_checkpoint
477
+ elif last_checkpoint is not None:
478
+ checkpoint = last_checkpoint
479
+
480
+ try:
481
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
482
+ trainer.save_model() # Saves the tokenizer too for easy upload
483
+ except KeyboardInterrupt:
484
+ print('Saving model')
485
+ trainer.save_model(os.path.join(
486
+ training_args.output_dir, 'checkpoint-latest')) # TODO use dir
487
+ raise
488
+
489
+ metrics = train_result.metrics
490
+ max_train_samples = data_training_args.max_train_samples or len(
491
+ train_dataset)
492
+ metrics['train_samples'] = min(max_train_samples, len(train_dataset))
493
+
494
+ trainer.log_metrics('train', metrics)
495
+ trainer.save_metrics('train', metrics)
496
+ trainer.save_state()
497
+
498
+ kwargs = {'finetuned_from': model_args.model_name_or_path,
499
+ 'tasks': 'summarization'}
500
+
501
+ if training_args.push_to_hub:
502
+ trainer.push_to_hub(**kwargs)
503
+ else:
504
+ trainer.create_model_card(**kwargs)
505
+
506
+
507
+ if __name__ == '__main__':
508
+ main()
src/utils.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+
4
+ class Job:
5
+ def __init__(self, function, *args, **kwargs) -> None:
6
+ self.function = function
7
+ self.args = args
8
+ self.kwargs = kwargs
9
+
10
+ self.result = None
11
+
12
+
13
+ class InterruptibleThreadPool:
14
+ def __init__(self,
15
+ num_workers=None,
16
+ loop=None,
17
+ shutdown_message='\nAttempting graceful shutdown, press Ctrl+C again to exit...',
18
+ on_job_complete=None, # Useful for monitoring progress
19
+ raise_after_interrupt=False,
20
+ ) -> None:
21
+ self.num_workers = os.cpu_count() if num_workers is None else num_workers
22
+ self.loop = asyncio.get_event_loop() if loop is None else loop
23
+ self.shutdown_message = shutdown_message
24
+
25
+ self.sem = asyncio.Semaphore(num_workers)
26
+
27
+ self.jobs = []
28
+
29
+ self.on_job_complete = on_job_complete
30
+ self.raise_after_interrupt = raise_after_interrupt
31
+
32
+ async def _sync_to_async(self, job):
33
+ async with self.sem: # Limit number of parallel tasks
34
+ job.result = await self.loop.run_in_executor(None, job.function, *job.args, **job.kwargs)
35
+
36
+ if callable(self.on_job_complete):
37
+ self.on_job_complete(job)
38
+
39
+ return job
40
+
41
+ def add_job(self, job):
42
+ self.jobs.append(job)
43
+
44
+ def run(self):
45
+ try:
46
+ tasks = [
47
+ # creating task starts coroutine
48
+ asyncio.ensure_future(self._sync_to_async(job))
49
+ for job in self.jobs
50
+ ]
51
+
52
+ # https://stackoverflow.com/a/42097478
53
+ self.loop.run_until_complete(
54
+ asyncio.gather(*tasks, return_exceptions=True)
55
+ )
56
+
57
+ except KeyboardInterrupt:
58
+ # Optionally show a message if the shutdown may take a while
59
+ print(self.shutdown_message, flush=True)
60
+
61
+ # Do not show `asyncio.CancelledError` exceptions during shutdown
62
+ # (a lot of these may be generated, skip this if you prefer to see them)
63
+ def shutdown_exception_handler(loop, context):
64
+ if "exception" not in context \
65
+ or not isinstance(context["exception"], asyncio.CancelledError):
66
+ loop.default_exception_handler(context)
67
+ self.loop.set_exception_handler(shutdown_exception_handler)
68
+
69
+ # Handle shutdown gracefully by waiting for all tasks to be cancelled
70
+ cancelled_tasks = asyncio.gather(
71
+ *asyncio.all_tasks(loop=self.loop), loop=self.loop, return_exceptions=True)
72
+ cancelled_tasks.add_done_callback(lambda t: self.loop.stop())
73
+ cancelled_tasks.cancel()
74
+
75
+ # Keep the event loop running until it is either destroyed or all
76
+ # tasks have really terminated
77
+ while not cancelled_tasks.done() and not self.loop.is_closed():
78
+ self.loop.run_forever()
79
+
80
+ if self.raise_after_interrupt:
81
+ raise
82
+ finally:
83
+ self.loop.run_until_complete(self.loop.shutdown_asyncgens())
84
+ self.loop.close()
85
+
86
+ return self.jobs