Spaces:
Runtime error
Runtime error
Joshua Lochner
commited on
Commit
•
5fbdd3c
1
Parent(s):
5f40236
Add source code
Browse files- src/errors.py +13 -0
- src/evaluate.py +244 -0
- src/model.py +111 -0
- src/predict.py +278 -0
- src/preprocess.py +786 -0
- src/segment.py +142 -0
- src/shared.py +96 -0
- src/train.py +508 -0
- src/utils.py +86 -0
src/errors.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class SponsorBlockException(Exception):
|
2 |
+
"""Base class for all sponsor block exceptions"""
|
3 |
+
pass
|
4 |
+
|
5 |
+
|
6 |
+
class PredictionException(SponsorBlockException):
|
7 |
+
"""An exception was occurred while predicting sponsor segments"""
|
8 |
+
pass
|
9 |
+
|
10 |
+
|
11 |
+
class TranscriptError(SponsorBlockException):
|
12 |
+
"""An exception was occurred while retrieving the video transcript"""
|
13 |
+
pass
|
src/evaluate.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset
|
2 |
+
from transformers import (
|
3 |
+
AutoModelForSeq2SeqLM,
|
4 |
+
AutoTokenizer,
|
5 |
+
HfArgumentParser
|
6 |
+
)
|
7 |
+
from preprocess import DatasetArguments, ProcessedArguments, get_words
|
8 |
+
from model import get_classifier_vectorizer
|
9 |
+
from shared import device
|
10 |
+
from predict import ClassifierArguments, PredictArguments, predict, filter_predictions
|
11 |
+
from segment import word_start, word_end, SegmentationArguments, add_labels_to_words
|
12 |
+
import pandas as pd
|
13 |
+
from dataclasses import dataclass, field
|
14 |
+
from typing import Optional
|
15 |
+
from tqdm import tqdm
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
import random
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class EvaluationArguments:
|
23 |
+
"""
|
24 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
25 |
+
"""
|
26 |
+
max_videos: Optional[int] = field(
|
27 |
+
default=None,
|
28 |
+
metadata={
|
29 |
+
'help': 'The number of videos to test on'
|
30 |
+
}
|
31 |
+
)
|
32 |
+
model_path: Optional[str] = PredictArguments.__dataclass_fields__[
|
33 |
+
'model_path']
|
34 |
+
data_dir: Optional[str] = DatasetArguments.__dataclass_fields__['data_dir']
|
35 |
+
dataset: Optional[str] = DatasetArguments.__dataclass_fields__[
|
36 |
+
'validation_file']
|
37 |
+
|
38 |
+
output_file: Optional[str] = field(
|
39 |
+
default='metrics.csv',
|
40 |
+
metadata={
|
41 |
+
'help': 'Save metrics to output file'
|
42 |
+
}
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
def jaccard(x1, x2, y1, y2):
|
47 |
+
# Calculate jaccard index
|
48 |
+
intersection = max(0, min(x2, y2)-max(x1, y1))
|
49 |
+
filled_union = max(x2, y2) - min(x1, y1)
|
50 |
+
return intersection/filled_union
|
51 |
+
|
52 |
+
|
53 |
+
def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
|
54 |
+
"""Attach sponsor segments to closest prediction"""
|
55 |
+
for prediction in predictions:
|
56 |
+
prediction['best_overlap'] = 0
|
57 |
+
prediction['best_sponsorship'] = None
|
58 |
+
|
59 |
+
# Assign predictions to actual (labelled) sponsored segments
|
60 |
+
for sponsor_segment in sponsor_segments:
|
61 |
+
sponsor_segment['best_overlap'] = 0
|
62 |
+
sponsor_segment['best_prediction'] = None
|
63 |
+
|
64 |
+
for prediction in predictions:
|
65 |
+
|
66 |
+
j = jaccard(prediction['start'], prediction['end'],
|
67 |
+
sponsor_segment['start'], sponsor_segment['end'])
|
68 |
+
if sponsor_segment['best_overlap'] < j:
|
69 |
+
sponsor_segment['best_overlap'] = j
|
70 |
+
sponsor_segment['best_prediction'] = prediction
|
71 |
+
|
72 |
+
if prediction['best_overlap'] < j:
|
73 |
+
prediction['best_overlap'] = j
|
74 |
+
prediction['best_sponsorship'] = sponsor_segment
|
75 |
+
|
76 |
+
return sponsor_segments
|
77 |
+
|
78 |
+
|
79 |
+
def calculate_metrics(labelled_words, predictions):
|
80 |
+
|
81 |
+
metrics = {
|
82 |
+
'true_positive': 0, # Is sponsor, predicted sponsor
|
83 |
+
# Is sponsor, predicted not sponsor (i.e., missed it - bad)
|
84 |
+
'false_negative': 0,
|
85 |
+
# Is not sponsor, predicted sponsor (classified incorectly, not that bad since we do manual checking afterwards)
|
86 |
+
'false_positive': 0,
|
87 |
+
'true_negative': 0, # Is not sponsor, predicted not sponsor
|
88 |
+
}
|
89 |
+
|
90 |
+
metrics['video_duration'] = word_end(
|
91 |
+
labelled_words[-1])-word_start(labelled_words[0])
|
92 |
+
|
93 |
+
for index, word in enumerate(labelled_words):
|
94 |
+
if index >= len(labelled_words) - 1:
|
95 |
+
continue
|
96 |
+
|
97 |
+
# TODO make sure words with manual transcripts
|
98 |
+
duration = labelled_words[index+1]['start'] - word['start']
|
99 |
+
|
100 |
+
predicted_sponsor = False
|
101 |
+
for p in predictions:
|
102 |
+
# Is in some prediction
|
103 |
+
if p['start'] <= word['start'] <= p['end']:
|
104 |
+
predicted_sponsor = True
|
105 |
+
break
|
106 |
+
|
107 |
+
if predicted_sponsor:
|
108 |
+
# total_positive_time += duration
|
109 |
+
if word['sponsor']: # Is actual sponsor
|
110 |
+
metrics['true_positive'] += duration
|
111 |
+
else:
|
112 |
+
metrics['false_positive'] += duration
|
113 |
+
else:
|
114 |
+
# total_negative_time += duration
|
115 |
+
if word['sponsor']: # Is actual sponsor
|
116 |
+
metrics['false_negative'] += duration
|
117 |
+
else:
|
118 |
+
metrics['true_negative'] += duration
|
119 |
+
|
120 |
+
# NOTE In cases where we encounter division by 0, we say that the value is 1
|
121 |
+
# https://stats.stackexchange.com/a/1775
|
122 |
+
# (Precision) TP+FP=0: means that all instances were predicted as negative
|
123 |
+
# (Recall) TP+FN=0: means that there were no positive cases in the input data
|
124 |
+
|
125 |
+
# The fraction of predictions our model got right
|
126 |
+
# Can simplify, but use full formula
|
127 |
+
z = metrics['true_positive'] + metrics['true_negative'] + \
|
128 |
+
metrics['false_positive'] + metrics['false_negative']
|
129 |
+
metrics['accuracy'] = (
|
130 |
+
(metrics['true_positive'] + metrics['true_negative']) / z) if z > 0 else 1
|
131 |
+
|
132 |
+
# What proportion of positive identifications was actually correct?
|
133 |
+
z = metrics['true_positive'] + metrics['false_positive']
|
134 |
+
metrics['precision'] = (metrics['true_positive'] / z) if z > 0 else 1
|
135 |
+
|
136 |
+
# What proportion of actual positives was identified correctly?
|
137 |
+
z = metrics['true_positive'] + metrics['false_negative']
|
138 |
+
metrics['recall'] = (metrics['true_positive'] / z) if z > 0 else 1
|
139 |
+
|
140 |
+
# https://deepai.org/machine-learning-glossary-and-terms/f-score
|
141 |
+
|
142 |
+
s = metrics['precision'] + metrics['recall']
|
143 |
+
metrics['f-score'] = (2 * (metrics['precision'] *
|
144 |
+
metrics['recall']) / s) if s > 0 else 0
|
145 |
+
|
146 |
+
return metrics
|
147 |
+
|
148 |
+
|
149 |
+
def main():
|
150 |
+
hf_parser = HfArgumentParser((
|
151 |
+
EvaluationArguments,
|
152 |
+
ProcessedArguments,
|
153 |
+
SegmentationArguments,
|
154 |
+
ClassifierArguments
|
155 |
+
))
|
156 |
+
|
157 |
+
evaluation_args, processed_args, segmentation_args, classifier_args = hf_parser.parse_args_into_dataclasses()
|
158 |
+
|
159 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(evaluation_args.model_path)
|
160 |
+
model.to(device())
|
161 |
+
|
162 |
+
tokenizer = AutoTokenizer.from_pretrained(evaluation_args.model_path)
|
163 |
+
|
164 |
+
dataset = load_dataset('json', data_files=os.path.join(
|
165 |
+
evaluation_args.data_dir, evaluation_args.dataset))['train']
|
166 |
+
|
167 |
+
video_ids = [row['video_id'] for row in dataset]
|
168 |
+
random.shuffle(video_ids) # TODO Make param
|
169 |
+
|
170 |
+
if evaluation_args.max_videos is not None:
|
171 |
+
video_ids = video_ids[:evaluation_args.max_videos]
|
172 |
+
|
173 |
+
# Load labelled data:
|
174 |
+
final_path = os.path.join(
|
175 |
+
processed_args.processed_dir, processed_args.processed_file)
|
176 |
+
|
177 |
+
with open(final_path) as fp:
|
178 |
+
final_data = json.load(fp)
|
179 |
+
|
180 |
+
classifier, vectorizer = get_classifier_vectorizer(classifier_args)
|
181 |
+
|
182 |
+
total_accuracy = 0
|
183 |
+
total_precision = 0
|
184 |
+
total_recall = 0
|
185 |
+
total_fscore = 0
|
186 |
+
|
187 |
+
count = 0
|
188 |
+
out_metrics = []
|
189 |
+
|
190 |
+
try:
|
191 |
+
with tqdm(video_ids) as progress:
|
192 |
+
for video_id in progress:
|
193 |
+
progress.set_description(f'Processing {video_id}')
|
194 |
+
sponsor_segments = final_data.get(video_id, [])
|
195 |
+
|
196 |
+
words = get_words(video_id)
|
197 |
+
if not words:
|
198 |
+
continue
|
199 |
+
|
200 |
+
count += 1
|
201 |
+
|
202 |
+
# Make predictions
|
203 |
+
predictions = predict(video_id, model, tokenizer,
|
204 |
+
segmentation_args, words)
|
205 |
+
|
206 |
+
# Filter predictions
|
207 |
+
predictions = filter_predictions(
|
208 |
+
predictions, classifier, vectorizer, classifier_args)
|
209 |
+
|
210 |
+
labelled_words = add_labels_to_words(words, sponsor_segments)
|
211 |
+
met = calculate_metrics(labelled_words, predictions)
|
212 |
+
met['video_id'] = video_id
|
213 |
+
|
214 |
+
out_metrics.append(met)
|
215 |
+
|
216 |
+
total_accuracy += met['accuracy']
|
217 |
+
total_precision += met['precision']
|
218 |
+
total_recall += met['recall']
|
219 |
+
total_fscore += met['f-score']
|
220 |
+
|
221 |
+
progress.set_postfix({
|
222 |
+
'accuracy': total_accuracy/count,
|
223 |
+
'precision': total_precision/count,
|
224 |
+
'recall': total_recall/count,
|
225 |
+
'f-score': total_fscore/count
|
226 |
+
})
|
227 |
+
|
228 |
+
labelled_predicted_segments = attach_predictions_to_sponsor_segments(
|
229 |
+
predictions, sponsor_segments)
|
230 |
+
for seg in labelled_predicted_segments:
|
231 |
+
if seg['best_prediction'] is None:
|
232 |
+
print('\nNo match found for', seg)
|
233 |
+
|
234 |
+
except KeyboardInterrupt:
|
235 |
+
pass
|
236 |
+
|
237 |
+
df = pd.DataFrame(out_metrics)
|
238 |
+
|
239 |
+
df.to_csv(evaluation_args.output_file)
|
240 |
+
print(df.mean())
|
241 |
+
|
242 |
+
|
243 |
+
if __name__ == '__main__':
|
244 |
+
main()
|
src/model.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import os
|
3 |
+
from shared import CustomTokens
|
4 |
+
from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
|
5 |
+
from dataclasses import dataclass, field
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
|
9 |
+
@dataclass
|
10 |
+
class ModelArguments:
|
11 |
+
"""
|
12 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
13 |
+
"""
|
14 |
+
|
15 |
+
model_name_or_path: str = field(
|
16 |
+
default='google/t5-v1_1-small', # t5-small
|
17 |
+
metadata={
|
18 |
+
'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
|
19 |
+
)
|
20 |
+
# config_name: Optional[str] = field( # TODO remove?
|
21 |
+
# default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'}
|
22 |
+
# )
|
23 |
+
tokenizer_name: Optional[str] = field(
|
24 |
+
default=None, metadata={'help': 'Pretrained tokenizer name or path if not the same as model_name'}
|
25 |
+
)
|
26 |
+
cache_dir: Optional[str] = field(
|
27 |
+
default=None,
|
28 |
+
metadata={
|
29 |
+
'help': 'Where to store the pretrained models downloaded from huggingface.co'},
|
30 |
+
)
|
31 |
+
use_fast_tokenizer: bool = field( # TODO remove?
|
32 |
+
default=True,
|
33 |
+
metadata={
|
34 |
+
'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'},
|
35 |
+
)
|
36 |
+
model_revision: str = field( # TODO remove?
|
37 |
+
default='main',
|
38 |
+
metadata={
|
39 |
+
'help': 'The specific model version to use (can be a branch name, tag name or commit id).'},
|
40 |
+
)
|
41 |
+
use_auth_token: bool = field(
|
42 |
+
default=False,
|
43 |
+
metadata={
|
44 |
+
'help': 'Will use the token generated when running `transformers-cli login` (necessary to use this script '
|
45 |
+
'with private models).'
|
46 |
+
},
|
47 |
+
)
|
48 |
+
resize_position_embeddings: Optional[bool] = field(
|
49 |
+
default=None,
|
50 |
+
metadata={
|
51 |
+
'help': "Whether to automatically resize the position embeddings if `max_source_length` exceeds the model's position embeddings."
|
52 |
+
},
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
def get_model(model_args, use_cache=True):
|
57 |
+
name = model_args.model_name_or_path
|
58 |
+
cached_path = f'models/{name}'
|
59 |
+
|
60 |
+
# Model created after tokenizer:
|
61 |
+
if use_cache and os.path.exists(os.path.join(cached_path, 'pytorch_model.bin')):
|
62 |
+
name = cached_path
|
63 |
+
|
64 |
+
config = AutoConfig.from_pretrained(
|
65 |
+
name,
|
66 |
+
cache_dir=model_args.cache_dir,
|
67 |
+
revision=model_args.model_revision,
|
68 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
69 |
+
)
|
70 |
+
|
71 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
72 |
+
name,
|
73 |
+
from_tf='.ckpt' in name,
|
74 |
+
config=config,
|
75 |
+
cache_dir=model_args.cache_dir,
|
76 |
+
revision=model_args.model_revision,
|
77 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
78 |
+
)
|
79 |
+
|
80 |
+
return model
|
81 |
+
|
82 |
+
|
83 |
+
def get_tokenizer(model_args, use_cache=True):
|
84 |
+
name = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
|
85 |
+
|
86 |
+
cached_path = f'models/{name}'
|
87 |
+
|
88 |
+
if use_cache and os.path.exists(os.path.join(cached_path, 'tokenizer.json')):
|
89 |
+
name = cached_path
|
90 |
+
|
91 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
92 |
+
name,
|
93 |
+
cache_dir=model_args.cache_dir,
|
94 |
+
use_fast=model_args.use_fast_tokenizer,
|
95 |
+
revision=model_args.model_revision,
|
96 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
97 |
+
)
|
98 |
+
|
99 |
+
CustomTokens.add_custom_tokens(tokenizer)
|
100 |
+
|
101 |
+
return tokenizer
|
102 |
+
|
103 |
+
|
104 |
+
def get_classifier_vectorizer(classifier_args):
|
105 |
+
with open(os.path.join(classifier_args.classifier_dir, classifier_args.classifier_file), 'rb') as fp:
|
106 |
+
classifier = pickle.load(fp)
|
107 |
+
|
108 |
+
with open(os.path.join(classifier_args.classifier_dir, classifier_args.vectorizer_file), 'rb') as fp:
|
109 |
+
vectorizer = pickle.load(fp)
|
110 |
+
|
111 |
+
return classifier, vectorizer
|
src/predict.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from shared import OutputArguments
|
2 |
+
from typing import Optional
|
3 |
+
from segment import (
|
4 |
+
generate_segments,
|
5 |
+
extract_segment,
|
6 |
+
SAFETY_TOKENS,
|
7 |
+
CustomTokens,
|
8 |
+
word_start,
|
9 |
+
word_end,
|
10 |
+
SegmentationArguments
|
11 |
+
)
|
12 |
+
import preprocess
|
13 |
+
import re
|
14 |
+
from errors import TranscriptError
|
15 |
+
from model import get_classifier_vectorizer
|
16 |
+
from transformers import (
|
17 |
+
AutoModelForSeq2SeqLM,
|
18 |
+
AutoTokenizer
|
19 |
+
)
|
20 |
+
from dataclasses import dataclass, field
|
21 |
+
from transformers import HfArgumentParser
|
22 |
+
from shared import device
|
23 |
+
import logging
|
24 |
+
from transformers.trainer_utils import get_last_checkpoint
|
25 |
+
|
26 |
+
|
27 |
+
def seconds_to_time(seconds):
|
28 |
+
h, remainder = divmod(abs(int(seconds)), 3600)
|
29 |
+
m, s = divmod(remainder, 60)
|
30 |
+
return f"{'-' if seconds < 0 else ''}{h:02}:{m:02}:{s:02}"
|
31 |
+
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class PredictArguments:
|
35 |
+
|
36 |
+
video_id: str = field(
|
37 |
+
metadata={
|
38 |
+
'help': 'Video to predict sponsorship segments for'}
|
39 |
+
)
|
40 |
+
|
41 |
+
model_path: str = field(
|
42 |
+
default=None,
|
43 |
+
metadata={
|
44 |
+
'help': 'Path to pretrained model used for prediction'}
|
45 |
+
)
|
46 |
+
|
47 |
+
output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
|
48 |
+
'output_dir']
|
49 |
+
|
50 |
+
def __post_init__(self):
|
51 |
+
if self.model_path is not None:
|
52 |
+
return
|
53 |
+
|
54 |
+
last_checkpoint = get_last_checkpoint(self.output_dir)
|
55 |
+
if last_checkpoint is not None:
|
56 |
+
self.model_path = last_checkpoint
|
57 |
+
else:
|
58 |
+
raise Exception(
|
59 |
+
'Unable to find model, explicitly set `--model_path`')
|
60 |
+
|
61 |
+
|
62 |
+
SPONSOR_MATCH_RE = fr'(?<={CustomTokens.START_SPONSOR.value})\s*(.*?)\s*(?={CustomTokens.END_SPONSOR.value}|$)'
|
63 |
+
|
64 |
+
MATCH_WINDOW = 25 # Increase for accuracy, but takes longer: O(n^3)
|
65 |
+
MERGE_TIME_WITHIN = 8 # Merge predictions if they are within x seconds
|
66 |
+
|
67 |
+
|
68 |
+
@dataclass
|
69 |
+
class ClassifierArguments:
|
70 |
+
classifier_dir: Optional[str] = field(
|
71 |
+
default='classifiers',
|
72 |
+
metadata={
|
73 |
+
'help': 'The directory that contains the classifier and vectorizer.'
|
74 |
+
}
|
75 |
+
)
|
76 |
+
|
77 |
+
classifier_file: Optional[str] = field(
|
78 |
+
default='classifier.pickle',
|
79 |
+
metadata={
|
80 |
+
'help': 'The name of the classifier'
|
81 |
+
}
|
82 |
+
)
|
83 |
+
|
84 |
+
vectorizer_file: Optional[str] = field(
|
85 |
+
default='vectorizer.pickle',
|
86 |
+
metadata={
|
87 |
+
'help': 'The name of the vectorizer'
|
88 |
+
}
|
89 |
+
)
|
90 |
+
|
91 |
+
min_probability: float = field(
|
92 |
+
default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
|
93 |
+
|
94 |
+
|
95 |
+
def filter_predictions(predictions, classifier, vectorizer, classifier_args):
|
96 |
+
"""Use classifier to filter predictions"""
|
97 |
+
if not predictions:
|
98 |
+
return predictions
|
99 |
+
|
100 |
+
transformed_segments = vectorizer.transform([
|
101 |
+
preprocess.clean_text(' '.join([x['text'] for x in pred['words']]))
|
102 |
+
for pred in predictions
|
103 |
+
])
|
104 |
+
probabilities = classifier.predict_proba(transformed_segments)
|
105 |
+
|
106 |
+
filtered_predictions = []
|
107 |
+
for prediction, probability in zip(predictions, probabilities):
|
108 |
+
prediction['probability'] = probability[1]
|
109 |
+
|
110 |
+
if prediction['probability'] >= classifier_args.min_probability:
|
111 |
+
filtered_predictions.append(prediction)
|
112 |
+
# else:
|
113 |
+
# print('removing segment', prediction)
|
114 |
+
|
115 |
+
return filtered_predictions
|
116 |
+
|
117 |
+
|
118 |
+
def predict(video_id, model, tokenizer, segmentation_args, words=None, classifier_args=None):
|
119 |
+
# Allow words to be passed in so that we don't have to get the words if we already have them
|
120 |
+
if words is None:
|
121 |
+
words = preprocess.get_words(video_id)
|
122 |
+
if not words:
|
123 |
+
raise TranscriptError('Unable to retrieve transcript')
|
124 |
+
|
125 |
+
segments = generate_segments(
|
126 |
+
words,
|
127 |
+
tokenizer,
|
128 |
+
segmentation_args
|
129 |
+
)
|
130 |
+
|
131 |
+
predictions = segments_to_prediction_times(segments, model, tokenizer)
|
132 |
+
|
133 |
+
# Add words back to time_ranges
|
134 |
+
for prediction in predictions:
|
135 |
+
# Stores words in the range
|
136 |
+
prediction['words'] = extract_segment(
|
137 |
+
words, prediction['start'], prediction['end'])
|
138 |
+
|
139 |
+
if classifier_args is not None:
|
140 |
+
classifier, vectorizer = get_classifier_vectorizer(classifier_args)
|
141 |
+
predictions = filter_predictions(
|
142 |
+
predictions, classifier, vectorizer, classifier_args)
|
143 |
+
|
144 |
+
return predictions
|
145 |
+
|
146 |
+
|
147 |
+
def greedy_match(list, sublist):
|
148 |
+
# Return index and length of longest matching sublist
|
149 |
+
|
150 |
+
best_i = -1
|
151 |
+
best_j = -1
|
152 |
+
best_k = 0
|
153 |
+
|
154 |
+
for i in range(len(list)): # Start position in main list
|
155 |
+
for j in range(len(sublist)): # Start position in sublist
|
156 |
+
for k in range(len(sublist)-j, 0, -1): # Width of sublist window
|
157 |
+
if k > best_k and list[i:i+k] == sublist[j:j+k]:
|
158 |
+
best_i, best_j, best_k = i, j, k
|
159 |
+
break # Since window size decreases
|
160 |
+
|
161 |
+
return best_i, best_j, best_k
|
162 |
+
|
163 |
+
|
164 |
+
DEFAULT_TOKEN_PREFIX = 'summarize: '
|
165 |
+
|
166 |
+
|
167 |
+
def predict_sponsor_text(text, model, tokenizer):
|
168 |
+
"""Given a body of text, predict the words which are part of the sponsor"""
|
169 |
+
input_ids = tokenizer(
|
170 |
+
f'{DEFAULT_TOKEN_PREFIX}{text}', return_tensors='pt', truncation=True).input_ids
|
171 |
+
|
172 |
+
# Can't be longer than input length + SAFETY_TOKENS or model input dim
|
173 |
+
max_out_len = min(len(input_ids[0]) + SAFETY_TOKENS, model.model_dim)
|
174 |
+
outputs = model.generate(input_ids, max_length=max_out_len)
|
175 |
+
|
176 |
+
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
177 |
+
|
178 |
+
|
179 |
+
def predict_sponsor_matches(text, model, tokenizer):
|
180 |
+
sponsorship_text = predict_sponsor_text(text, model, tokenizer)
|
181 |
+
if CustomTokens.NO_SPONSOR.value in sponsorship_text:
|
182 |
+
return []
|
183 |
+
|
184 |
+
return re.findall(SPONSOR_MATCH_RE, sponsorship_text)
|
185 |
+
|
186 |
+
|
187 |
+
def segments_to_prediction_times(segments, model, tokenizer):
|
188 |
+
predicted_time_ranges = []
|
189 |
+
|
190 |
+
# TODO pass to model simultaneously, not in for loop
|
191 |
+
# use 2d array for input ids
|
192 |
+
for segment in segments:
|
193 |
+
cleaned_batch = [preprocess.clean_text(
|
194 |
+
word['text']) for word in segment]
|
195 |
+
batch_text = ' '.join(cleaned_batch)
|
196 |
+
|
197 |
+
matches = predict_sponsor_matches(batch_text, model, tokenizer)
|
198 |
+
|
199 |
+
for match in matches:
|
200 |
+
matched_text = match.split()
|
201 |
+
# TODO skip if too short
|
202 |
+
|
203 |
+
i1, j1, k1 = greedy_match(
|
204 |
+
cleaned_batch, matched_text[:MATCH_WINDOW])
|
205 |
+
i2, j2, k2 = greedy_match(
|
206 |
+
cleaned_batch, matched_text[-MATCH_WINDOW:])
|
207 |
+
|
208 |
+
extracted_words = segment[i1:i2+k2]
|
209 |
+
|
210 |
+
if not extracted_words:
|
211 |
+
continue
|
212 |
+
|
213 |
+
predicted_time_ranges.append({
|
214 |
+
'start': word_start(extracted_words[0]),
|
215 |
+
'end': word_end(extracted_words[-1])
|
216 |
+
})
|
217 |
+
|
218 |
+
# Necessary to sort matches by start time
|
219 |
+
predicted_time_ranges.sort(key=word_start)
|
220 |
+
|
221 |
+
# Merge overlapping predictions and sponsorships that are close together
|
222 |
+
# Caused by model having max input size
|
223 |
+
last_end_time = -1
|
224 |
+
final_predicted_time_ranges = []
|
225 |
+
for range in predicted_time_ranges:
|
226 |
+
start_time = range['start']
|
227 |
+
end_time = range['end']
|
228 |
+
|
229 |
+
if (start_time <= last_end_time <= end_time) or (last_end_time != -1 and start_time - last_end_time <= MERGE_TIME_WITHIN):
|
230 |
+
# Ending time of last segment is in this segment, so we extend last prediction range
|
231 |
+
final_predicted_time_ranges[-1]['end'] = end_time
|
232 |
+
|
233 |
+
else: # No overlap, is a new prediction
|
234 |
+
final_predicted_time_ranges.append({
|
235 |
+
'start': start_time,
|
236 |
+
'end': end_time,
|
237 |
+
})
|
238 |
+
|
239 |
+
last_end_time = end_time
|
240 |
+
|
241 |
+
return final_predicted_time_ranges
|
242 |
+
|
243 |
+
|
244 |
+
def main():
|
245 |
+
# Test on unseen data
|
246 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
247 |
+
|
248 |
+
hf_parser = HfArgumentParser((
|
249 |
+
PredictArguments,
|
250 |
+
SegmentationArguments,
|
251 |
+
ClassifierArguments
|
252 |
+
))
|
253 |
+
predict_args, segmentation_args, classifier_args = hf_parser.parse_args_into_dataclasses()
|
254 |
+
|
255 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(predict_args.model_path)
|
256 |
+
model.to(device())
|
257 |
+
|
258 |
+
tokenizer = AutoTokenizer.from_pretrained(predict_args.model_path)
|
259 |
+
|
260 |
+
predict_args.video_id = predict_args.video_id.strip()
|
261 |
+
print(
|
262 |
+
f'Predicting for https://www.youtube.com/watch?v={predict_args.video_id}')
|
263 |
+
predictions = predict(predict_args.video_id, model, tokenizer,
|
264 |
+
segmentation_args, classifier_args=classifier_args)
|
265 |
+
|
266 |
+
for prediction in predictions:
|
267 |
+
print(' '.join([w['text'] for w in prediction['words']]))
|
268 |
+
print(seconds_to_time(prediction['start']),
|
269 |
+
'-->', seconds_to_time(prediction['end']))
|
270 |
+
print(prediction['start'], '-->', prediction['end'])
|
271 |
+
print(prediction['probability'])
|
272 |
+
print()
|
273 |
+
|
274 |
+
print()
|
275 |
+
|
276 |
+
|
277 |
+
if __name__ == '__main__':
|
278 |
+
main()
|
src/preprocess.py
ADDED
@@ -0,0 +1,786 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
from typing import Optional
|
3 |
+
from datasets import load_dataset
|
4 |
+
from model import ModelArguments
|
5 |
+
import segment
|
6 |
+
from tqdm import tqdm
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from transformers import HfArgumentParser
|
9 |
+
from shared import GeneralArguments, CustomTokens
|
10 |
+
import csv
|
11 |
+
import re
|
12 |
+
import random
|
13 |
+
import logging
|
14 |
+
from youtube_transcript_api import YouTubeTranscriptApi
|
15 |
+
from youtube_transcript_api._errors import CouldNotRetrieveTranscript, YouTubeRequestFailed
|
16 |
+
import os
|
17 |
+
import json
|
18 |
+
import time
|
19 |
+
import requests
|
20 |
+
from utils import InterruptibleThreadPool, Job
|
21 |
+
|
22 |
+
|
23 |
+
def find(s, ch):
|
24 |
+
return [i for i, ltr in enumerate(s) if ltr == ch]
|
25 |
+
|
26 |
+
|
27 |
+
def wordify(transcript):
|
28 |
+
"""Try to replicate format for automatically generated transcripts"""
|
29 |
+
words = []
|
30 |
+
|
31 |
+
for line_index, line in enumerate(transcript):
|
32 |
+
text = line['text'].replace('\n', ' ').strip()
|
33 |
+
if not text:
|
34 |
+
continue
|
35 |
+
|
36 |
+
start = line['start']
|
37 |
+
next_start = transcript[line_index +
|
38 |
+
1]['start'] if line_index < len(transcript) - 1 else float('inf')
|
39 |
+
end = min(start + line['duration'], next_start)
|
40 |
+
duration = end - start
|
41 |
+
|
42 |
+
indices = find(text, ' ') + [len(text)]
|
43 |
+
start_index = 0
|
44 |
+
for i in range(len(indices)):
|
45 |
+
word = text[start_index:indices[i]].strip()
|
46 |
+
if not word:
|
47 |
+
continue # Skip empty words (e.g., \n)
|
48 |
+
percentage = start_index/indices[-1]
|
49 |
+
|
50 |
+
w_duration = len(word)/indices[-1] * duration
|
51 |
+
|
52 |
+
w_start = start + percentage * duration
|
53 |
+
|
54 |
+
words.append({
|
55 |
+
'start': round(w_start, 5),
|
56 |
+
'duration': round(w_duration, 5),
|
57 |
+
'end': round(w_start + w_duration, 5),
|
58 |
+
'text': word,
|
59 |
+
})
|
60 |
+
|
61 |
+
start_index = indices[i] + 1
|
62 |
+
|
63 |
+
return words
|
64 |
+
|
65 |
+
|
66 |
+
def get_manual_words(transcript_list):
|
67 |
+
transcript = transcript_list.find_manually_created_transcript(
|
68 |
+
['en-GB', 'en-US', 'en']).fetch()
|
69 |
+
return wordify(transcript)
|
70 |
+
|
71 |
+
|
72 |
+
def get_auto_words(transcript_list):
|
73 |
+
words = []
|
74 |
+
transcript = transcript_list.find_generated_transcript(['en'])
|
75 |
+
url = transcript._url + '&fmt=json3'
|
76 |
+
info = transcript._http_client.get(url)
|
77 |
+
|
78 |
+
for event in info.json()['events']:
|
79 |
+
start_ms = event.get('tStartMs', 0)
|
80 |
+
|
81 |
+
for word in event.get('segs') or []:
|
82 |
+
offset_ms = word.get('tOffsetMs', 0)
|
83 |
+
|
84 |
+
texts = word['utf8'].replace(
|
85 |
+
CustomTokens.PROFANITY_RAW.value, CustomTokens.PROFANITY_CONVERTED.value
|
86 |
+
).strip().split()
|
87 |
+
|
88 |
+
for text in texts:
|
89 |
+
words.append({
|
90 |
+
'start': (start_ms + offset_ms)/1000,
|
91 |
+
'text': text
|
92 |
+
})
|
93 |
+
|
94 |
+
return words
|
95 |
+
|
96 |
+
|
97 |
+
def get_words(video_id, process=True, fallback=False, transcript_type='auto'):
|
98 |
+
"""Get parsed video transcript with caching system
|
99 |
+
returns None if not processed yet and process is False
|
100 |
+
"""
|
101 |
+
get_manual_if_fail = fallback and transcript_type == 'auto'
|
102 |
+
transcript_path = os.path.join(
|
103 |
+
'transcripts', transcript_type, f'{video_id}.json')
|
104 |
+
words = []
|
105 |
+
try:
|
106 |
+
if os.path.exists(transcript_path):
|
107 |
+
with open(transcript_path) as fp:
|
108 |
+
wds = json.load(fp)
|
109 |
+
|
110 |
+
if not wds and get_manual_if_fail:
|
111 |
+
return get_words(video_id, process, fallback, 'manual')
|
112 |
+
return wds
|
113 |
+
|
114 |
+
elif not process:
|
115 |
+
return None
|
116 |
+
|
117 |
+
transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
|
118 |
+
|
119 |
+
if transcript_type == 'manual':
|
120 |
+
words = get_manual_words(transcript_list)
|
121 |
+
else:
|
122 |
+
words = get_auto_words(transcript_list)
|
123 |
+
|
124 |
+
except YouTubeRequestFailed as e:
|
125 |
+
print(e)
|
126 |
+
time.sleep(30) # Timeout
|
127 |
+
return get_words(video_id, process, fallback, transcript_type)
|
128 |
+
|
129 |
+
except CouldNotRetrieveTranscript:
|
130 |
+
if get_manual_if_fail:
|
131 |
+
print('fallback')
|
132 |
+
return get_words(video_id, process, fallback, 'manual')
|
133 |
+
|
134 |
+
except json.decoder.JSONDecodeError:
|
135 |
+
# Warning, unable to parse JSON
|
136 |
+
pass
|
137 |
+
|
138 |
+
with open(transcript_path, 'w') as fp:
|
139 |
+
json.dump(words, fp)
|
140 |
+
|
141 |
+
return words
|
142 |
+
|
143 |
+
|
144 |
+
# TODO make min_sponsor_segment_length param
|
145 |
+
def extract_sponsors(words, min_sponsor_segment_length=5):
|
146 |
+
if len(words) < min_sponsor_segment_length:
|
147 |
+
return [] # Force short phrases to not be sponsors
|
148 |
+
|
149 |
+
paragraphs = []
|
150 |
+
current = []
|
151 |
+
for word in words:
|
152 |
+
if not word.get('sponsor') and not current:
|
153 |
+
continue
|
154 |
+
|
155 |
+
if word['sponsor']:
|
156 |
+
current.append(word['text'])
|
157 |
+
else:
|
158 |
+
paragraphs.append(current)
|
159 |
+
current = []
|
160 |
+
if current:
|
161 |
+
paragraphs.append(current)
|
162 |
+
|
163 |
+
# Remove all too short:
|
164 |
+
paragraphs = list(filter(lambda x: len(
|
165 |
+
x) >= min_sponsor_segment_length, paragraphs))
|
166 |
+
|
167 |
+
return paragraphs
|
168 |
+
|
169 |
+
|
170 |
+
def clean_text(text):
|
171 |
+
|
172 |
+
# Replace impossibly long words with a special token
|
173 |
+
# Usually the result of incorrect labelling
|
174 |
+
text = re.sub(r'\w{64,}', CustomTokens.LONG_WORD.value, text)
|
175 |
+
|
176 |
+
SHORT_HYPHENATED_REGEX = r'\w{1,2}(?:-\w{1,2}){3,}(?:-?\w*)'
|
177 |
+
|
178 |
+
# Replace hyphenated URLs with special token
|
179 |
+
# For some reason, youtube sometimes transcribes urls in this form:
|
180 |
+
# 'b-a-b-b-e-l-dot-com', 'g-e-t-r-o-m-a-n-com'
|
181 |
+
# not 'e-commerce'
|
182 |
+
text = re.sub(f'{SHORT_HYPHENATED_REGEX}(?:com|org|net)',
|
183 |
+
CustomTokens.HYPHENATED_URL.value, text)
|
184 |
+
|
185 |
+
# Replace short+hyphenated text with a special token. Of the form:
|
186 |
+
# 'i-i-i-i-i-i-i-i-i-i-i-i', 'b-u-m-f-u-z-z-l-e', 'v-e-r-i-t-a-s-i-u-m', 'do-do-do-do-do'
|
187 |
+
text = re.sub(SHORT_HYPHENATED_REGEX,
|
188 |
+
CustomTokens.SHORT_HYPHENATED.value, text)
|
189 |
+
|
190 |
+
# Replace URLs with URL_TOKEN
|
191 |
+
URL_REGEX = r'(?:(?:http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.(?:[a-zA-Z]){2,6}(?:[a-zA-Z0-9\.\&\/\?\:@\-_=#%])*'
|
192 |
+
text = re.sub(URL_REGEX, CustomTokens.URL.value, text)
|
193 |
+
|
194 |
+
NUM_REGEX = r'(?:\d+,)*(?:\d*[.])?\d+'
|
195 |
+
|
196 |
+
# Encode specific numeric words
|
197 |
+
# Of the form: 12%, 12.34%
|
198 |
+
# Usually included in sponsorships
|
199 |
+
text = re.sub(f'{NUM_REGEX}%',
|
200 |
+
CustomTokens.NUMBER_PERCENTAGE.value, text)
|
201 |
+
|
202 |
+
# Normal numbers, should not have an effect on sponsorship
|
203 |
+
text = re.sub(NUM_REGEX, CustomTokens.NUMBER.value, text)
|
204 |
+
|
205 |
+
# Replace profanity with special token
|
206 |
+
text = text.replace(CustomTokens.PROFANITY_RAW.value,
|
207 |
+
CustomTokens.PROFANITY.value)
|
208 |
+
text = text.replace(CustomTokens.PROFANITY_CONVERTED.value,
|
209 |
+
CustomTokens.PROFANITY.value)
|
210 |
+
|
211 |
+
return text.strip()
|
212 |
+
|
213 |
+
|
214 |
+
def remove_duplicate_sponsor_segments(sponsor_segments):
|
215 |
+
"""Choose the best sponsor segment if overlapping with others"""
|
216 |
+
|
217 |
+
# Algorithm based on SponsorBlock algorithm
|
218 |
+
# Find sponsors that are overlapping
|
219 |
+
similar = []
|
220 |
+
for i in sponsor_segments:
|
221 |
+
for j in sponsor_segments:
|
222 |
+
# Since we do pairwise, we only check one direction
|
223 |
+
if (j['start'] >= i['start'] and j['start'] <= i['end']):
|
224 |
+
similar.append([i, j])
|
225 |
+
|
226 |
+
# Within each group, choose the segment with the most votes.
|
227 |
+
processed = []
|
228 |
+
best = []
|
229 |
+
for i in similar:
|
230 |
+
if i in processed:
|
231 |
+
continue
|
232 |
+
group = i
|
233 |
+
for j in similar:
|
234 |
+
if j[0] in group or j[1] in group: # If either in, append both
|
235 |
+
group.append(j[0])
|
236 |
+
group.append(j[1])
|
237 |
+
processed.append(j)
|
238 |
+
|
239 |
+
best.append(max(group, key=lambda item: (
|
240 |
+
item['votes'], item['reputation'], item['views'])))
|
241 |
+
|
242 |
+
return best
|
243 |
+
|
244 |
+
|
245 |
+
@dataclass
|
246 |
+
class PreprocessArguments:
|
247 |
+
"""
|
248 |
+
Arguments pertaining to what data we are going to preprocess.
|
249 |
+
"""
|
250 |
+
update_database: bool = field(
|
251 |
+
default=False, metadata={'help': 'Download the raw database.'}
|
252 |
+
)
|
253 |
+
|
254 |
+
do_create: bool = field(
|
255 |
+
default=False, metadata={'help': 'Merge sponsor segments into single file'}
|
256 |
+
)
|
257 |
+
min_votes: int = field(
|
258 |
+
default=0, metadata={'help': 'Minimum number of votes'})
|
259 |
+
# Downvotes will make this negative.
|
260 |
+
# 1 = At least one positive vote
|
261 |
+
|
262 |
+
do_transcribe: bool = field(
|
263 |
+
default=False, metadata={'help': 'Get transcripts for videos'}
|
264 |
+
)
|
265 |
+
num_jobs: int = field(
|
266 |
+
default=4, metadata={'help': 'Number of transcripts to download in parallel'})
|
267 |
+
|
268 |
+
overwrite: bool = field(
|
269 |
+
default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
|
270 |
+
)
|
271 |
+
|
272 |
+
do_generate: bool = field(
|
273 |
+
default=False, metadata={'help': 'Generate labelled data.'}
|
274 |
+
)
|
275 |
+
|
276 |
+
do_split: bool = field(
|
277 |
+
default=False, metadata={'help': 'Generate training, testing and validation data.'}
|
278 |
+
)
|
279 |
+
percentage_positive: float = field(
|
280 |
+
default=0.5, metadata={'help': 'Ratio of positive (sponsor) segments to include in final output'})
|
281 |
+
|
282 |
+
train_split: float = field(
|
283 |
+
default=0.9, metadata={'help': 'Ratio of training data. Value between 0 and 1.'})
|
284 |
+
|
285 |
+
# TODO play around with ratios? lower test/validation split?
|
286 |
+
test_split: float = field(
|
287 |
+
default=0.05, metadata={'help': 'Ratio of testing data. Value between 0 and 1.'})
|
288 |
+
valid_split: float = field(
|
289 |
+
default=0.05, metadata={'help': 'Ratio of validation data. Value between 0 and 1.'})
|
290 |
+
|
291 |
+
skip_videos: int = field(default=None, metadata={
|
292 |
+
'help': 'Number of videos to skip. Set this to the latest video index to append to the current file'})
|
293 |
+
|
294 |
+
max_videos: int = field(default=None, metadata={
|
295 |
+
'help': 'Maximum number of videos to preprocess.'})
|
296 |
+
|
297 |
+
max_segments: int = field(default=None, metadata={
|
298 |
+
'help': 'Maximum number of segments to produce to preprocess.'})
|
299 |
+
|
300 |
+
raw_data_dir: Optional[str] = field(
|
301 |
+
default='raw',
|
302 |
+
metadata={
|
303 |
+
'help': 'Raw data directory'
|
304 |
+
},
|
305 |
+
)
|
306 |
+
raw_data_file: Optional[str] = field(
|
307 |
+
default='sponsorTimes.csv',
|
308 |
+
metadata={
|
309 |
+
'help': 'Raw data file'
|
310 |
+
},
|
311 |
+
)
|
312 |
+
|
313 |
+
min_wps: float = field(
|
314 |
+
default=0.4, metadata={'help': 'Ignore videos with not enough words spoken per second. This is usually indicitive of video whose captions aren\'t English.'})
|
315 |
+
# 0.1 ~ 1%
|
316 |
+
# 0.4 ~ 2.5%
|
317 |
+
# 0.9 ~ 5%
|
318 |
+
|
319 |
+
|
320 |
+
# Mirrors for database
|
321 |
+
MIRRORS = [
|
322 |
+
'https://sponsor.ajay.app/database/sponsorTimes.csv', # Latest
|
323 |
+
'https://sb-mirror.mchang.xyz/sponsorTimes.csv', # 5 minute delay
|
324 |
+
'https://sb.ltn.fi/database/sponsorTimes.csv', # 5 minute delay
|
325 |
+
]
|
326 |
+
# TODO only download latest (updates/changes)
|
327 |
+
|
328 |
+
|
329 |
+
def download_file(url, filename):
|
330 |
+
"""
|
331 |
+
Helper method handling downloading large files from `url` to `filename`.
|
332 |
+
|
333 |
+
Adapted from https://stackoverflow.com/a/42071418
|
334 |
+
"""
|
335 |
+
chunk_size = 1024
|
336 |
+
r = requests.get(url, stream=True)
|
337 |
+
total_bytes = int(r.headers['Content-Length'])
|
338 |
+
with open(filename, 'wb') as f, tqdm(unit='B', total=total_bytes) as progress:
|
339 |
+
for chunk in r.iter_content(chunk_size=chunk_size):
|
340 |
+
if chunk: # filter out keep-alive new chunks
|
341 |
+
progress.update(len(chunk))
|
342 |
+
f.write(chunk)
|
343 |
+
|
344 |
+
return total_bytes == os.path.getsize(filename)
|
345 |
+
|
346 |
+
|
347 |
+
@dataclass
|
348 |
+
class ProcessedArguments:
|
349 |
+
processed_dir: Optional[str] = field(
|
350 |
+
default='processed',
|
351 |
+
metadata={
|
352 |
+
'help': 'Processed data directory'
|
353 |
+
},
|
354 |
+
)
|
355 |
+
processed_file: Optional[str] = field(
|
356 |
+
default='final.json',
|
357 |
+
metadata={
|
358 |
+
'help': 'Processed data file'
|
359 |
+
},
|
360 |
+
)
|
361 |
+
|
362 |
+
|
363 |
+
def load_datasets(dataset_args):
|
364 |
+
print('Reading datasets')
|
365 |
+
data_files = {}
|
366 |
+
|
367 |
+
if dataset_args.train_file is not None:
|
368 |
+
data_files['train'] = os.path.join(
|
369 |
+
dataset_args.data_dir, dataset_args.train_file)
|
370 |
+
if dataset_args.validation_file is not None:
|
371 |
+
data_files['validation'] = os.path.join(
|
372 |
+
dataset_args.data_dir, dataset_args.validation_file)
|
373 |
+
if dataset_args.test_file is not None:
|
374 |
+
data_files['test'] = os.path.join(
|
375 |
+
dataset_args.data_dir, dataset_args.test_file)
|
376 |
+
|
377 |
+
return load_dataset('json', data_files=data_files)
|
378 |
+
|
379 |
+
|
380 |
+
@dataclass
|
381 |
+
class DatasetArguments:
|
382 |
+
data_dir: Optional[str] = field(
|
383 |
+
default='data',
|
384 |
+
metadata={
|
385 |
+
'help': 'The directory which stores train, test and/or validation data.'
|
386 |
+
},
|
387 |
+
)
|
388 |
+
|
389 |
+
train_file: Optional[str] = field(
|
390 |
+
default='train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
|
391 |
+
)
|
392 |
+
validation_file: Optional[str] = field(
|
393 |
+
default='valid.json',
|
394 |
+
metadata={
|
395 |
+
'help': 'An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines file).'
|
396 |
+
},
|
397 |
+
)
|
398 |
+
test_file: Optional[str] = field(
|
399 |
+
default='test.json',
|
400 |
+
metadata={
|
401 |
+
'help': 'An optional input test data file to evaluate the metrics (rouge) on (a jsonlines file).'
|
402 |
+
},
|
403 |
+
)
|
404 |
+
excess_file: Optional[str] = field(
|
405 |
+
default='excess.json',
|
406 |
+
metadata={
|
407 |
+
'help': 'The excess segments left after the split'
|
408 |
+
},
|
409 |
+
)
|
410 |
+
|
411 |
+
overwrite_cache: bool = field(
|
412 |
+
default=False, metadata={'help': 'Overwrite the cached training and evaluation sets'}
|
413 |
+
)
|
414 |
+
|
415 |
+
positive_file: Optional[str] = field(
|
416 |
+
default='sponsor_segments.json', metadata={'help': 'File to output sponsored segments to (a jsonlines file).'}
|
417 |
+
)
|
418 |
+
negative_file: Optional[str] = field(
|
419 |
+
default='normal_segments.json', metadata={'help': 'File to output normal segments to (a jsonlines file).'}
|
420 |
+
)
|
421 |
+
|
422 |
+
def __post_init__(self):
|
423 |
+
# TODO check if train/validation datasets exist
|
424 |
+
if self.train_file is None and self.validation_file is None:
|
425 |
+
raise ValueError(
|
426 |
+
'Need either a dataset name or a training/validation file.')
|
427 |
+
|
428 |
+
|
429 |
+
def main():
|
430 |
+
# Responsible for getting transcrips using youtube_transcript_api,
|
431 |
+
# then labelling it according to SponsorBlock's API
|
432 |
+
|
433 |
+
logging.getLogger().setLevel(logging.INFO) # TODO make param
|
434 |
+
|
435 |
+
# Generate final.json from sponsorTimes.csv
|
436 |
+
hf_parser = HfArgumentParser((
|
437 |
+
PreprocessArguments,
|
438 |
+
ProcessedArguments,
|
439 |
+
DatasetArguments,
|
440 |
+
segment.SegmentationArguments,
|
441 |
+
ModelArguments,
|
442 |
+
GeneralArguments
|
443 |
+
))
|
444 |
+
preprocess_args, processed_args, dataset_args, segmentation_args, model_args, _ = hf_parser.parse_args_into_dataclasses()
|
445 |
+
|
446 |
+
raw_dataset_path = os.path.join(
|
447 |
+
preprocess_args.raw_data_dir, preprocess_args.raw_data_file)
|
448 |
+
|
449 |
+
def get_rows():
|
450 |
+
with open(raw_dataset_path, newline='') as csvfile:
|
451 |
+
reader = csv.DictReader(csvfile)
|
452 |
+
for line in reader:
|
453 |
+
if line['service'] != 'YouTube':
|
454 |
+
continue
|
455 |
+
|
456 |
+
# TODO add support for other categories and action types?
|
457 |
+
if line['category'] != 'sponsor':
|
458 |
+
continue
|
459 |
+
if line['actionType'] != 'skip':
|
460 |
+
continue
|
461 |
+
|
462 |
+
# Ignore hidden items
|
463 |
+
if line['hidden'] == '1' or line['shadowHidden'] == '1':
|
464 |
+
continue
|
465 |
+
|
466 |
+
if len(line['videoID']) != 11:
|
467 |
+
continue # Invalid youtube video ID
|
468 |
+
|
469 |
+
# Skip those that aren't highly voted
|
470 |
+
line['votes'] = int(line['votes'])
|
471 |
+
# incorrect_votes = int(line['incorrectVotes'])
|
472 |
+
|
473 |
+
if line['votes'] < preprocess_args.min_votes:
|
474 |
+
continue
|
475 |
+
|
476 |
+
yield line
|
477 |
+
|
478 |
+
if preprocess_args.update_database:
|
479 |
+
print('Updating database')
|
480 |
+
for mirror in MIRRORS:
|
481 |
+
print('Downloading from', mirror)
|
482 |
+
if download_file(mirror, raw_dataset_path):
|
483 |
+
break
|
484 |
+
print('Failed, trying next')
|
485 |
+
|
486 |
+
# 'videoID', 'startTime', 'endTime', 'votes', 'locked', 'incorrectVotes', 'UUID',
|
487 |
+
# 'userID', 'timeSubmitted', 'views', 'category', 'actionType', 'service', 'videoDuration',
|
488 |
+
# 'hidden', 'reputation', 'shadowHidden', 'hashedVideoID', 'userAgent', 'description'
|
489 |
+
data_rows = None
|
490 |
+
if preprocess_args.do_transcribe:
|
491 |
+
print('Collecting videos')
|
492 |
+
video_ids = set()
|
493 |
+
data_rows = get_rows()
|
494 |
+
for row in data_rows:
|
495 |
+
video_ids.add(row['videoID'])
|
496 |
+
|
497 |
+
print('Start transcribing')
|
498 |
+
with tqdm(total=len(video_ids)) as progress:
|
499 |
+
def on_job_complete(job):
|
500 |
+
progress.set_description(f'Processed {job.video_id}')
|
501 |
+
progress.update()
|
502 |
+
|
503 |
+
pool = InterruptibleThreadPool(
|
504 |
+
preprocess_args.num_jobs, on_job_complete=on_job_complete)
|
505 |
+
|
506 |
+
print('Adding jobs to pool')
|
507 |
+
for video_id in video_ids:
|
508 |
+
job = Job(get_words, video_id)
|
509 |
+
job.video_id = video_id
|
510 |
+
pool.add_job(job)
|
511 |
+
|
512 |
+
print('Start processing')
|
513 |
+
pool.run()
|
514 |
+
|
515 |
+
print('Finished transcribing')
|
516 |
+
|
517 |
+
final_path = os.path.join(
|
518 |
+
processed_args.processed_dir, processed_args.processed_file)
|
519 |
+
|
520 |
+
if os.path.exists(final_path) and not preprocess_args.do_create:
|
521 |
+
logging.info(f'{final_path} exists, opening file')
|
522 |
+
with open(final_path) as fp:
|
523 |
+
final_data = json.load(fp)
|
524 |
+
else:
|
525 |
+
print('Create final data')
|
526 |
+
|
527 |
+
final_data = {}
|
528 |
+
|
529 |
+
if data_rows is None:
|
530 |
+
data_rows = get_rows()
|
531 |
+
|
532 |
+
# TODO add progress bar
|
533 |
+
# TODO parallelise?
|
534 |
+
for line in data_rows:
|
535 |
+
video_id = line['videoID']
|
536 |
+
|
537 |
+
if video_id not in final_data:
|
538 |
+
final_data[video_id] = []
|
539 |
+
|
540 |
+
segment_start = float(line['startTime'])
|
541 |
+
segment_end = float(line['endTime'])
|
542 |
+
|
543 |
+
video_words = get_words(video_id, process=True)
|
544 |
+
segment_words = segment.extract_segment(
|
545 |
+
video_words, segment_start, segment_end)
|
546 |
+
|
547 |
+
if len(segment_words) <= 1:
|
548 |
+
continue # Useless to add segment since no words
|
549 |
+
|
550 |
+
# duration = segment.word_end(segment_words[-1]) - segment.word_start(segment_words[0])
|
551 |
+
duration = segment_end - segment_start
|
552 |
+
wps = len(segment_words)/duration if duration > 0 else 0
|
553 |
+
|
554 |
+
if wps < preprocess_args.min_wps:
|
555 |
+
print('bad segment in', video_id, '| wps =', wps)
|
556 |
+
continue
|
557 |
+
|
558 |
+
final_data[video_id].append({
|
559 |
+
'start': segment_start,
|
560 |
+
'end': segment_end,
|
561 |
+
'votes': line['votes'],
|
562 |
+
'locked': line['locked'] == '1',
|
563 |
+
'views': line['views'],
|
564 |
+
'reputation': line['reputation'],
|
565 |
+
'category': line['category'],
|
566 |
+
'action': line['actionType'],
|
567 |
+
'uuid': line['UUID'],
|
568 |
+
})
|
569 |
+
|
570 |
+
# Remove duplicate sponsor segments by choosing best (most votes)
|
571 |
+
for key in final_data:
|
572 |
+
final_data[key] = remove_duplicate_sponsor_segments(
|
573 |
+
final_data[key])
|
574 |
+
|
575 |
+
# Save data
|
576 |
+
with open(final_path, 'w') as fp:
|
577 |
+
json.dump(final_data, fp)
|
578 |
+
|
579 |
+
# final_data = preprocess(
|
580 |
+
# raw_dataset_path, final_path, preprocess_args.min_votes)
|
581 |
+
# # TODO save metadata in final.json?
|
582 |
+
|
583 |
+
logging.info(f'Found {len(final_data)} videos')
|
584 |
+
|
585 |
+
# TODO shuffle final_data
|
586 |
+
|
587 |
+
# if not os.path.exists(excess_path) or preprocess_args.overwrite
|
588 |
+
# TODO use overwrite param
|
589 |
+
|
590 |
+
os.makedirs(dataset_args.data_dir, exist_ok=True)
|
591 |
+
|
592 |
+
positive_file = os.path.join(
|
593 |
+
dataset_args.data_dir, dataset_args.positive_file)
|
594 |
+
negative_file = os.path.join(
|
595 |
+
dataset_args.data_dir, dataset_args.negative_file)
|
596 |
+
|
597 |
+
if preprocess_args.do_generate:
|
598 |
+
print('Generating')
|
599 |
+
from model import get_tokenizer
|
600 |
+
|
601 |
+
# max_videos=preprocess_args.max_videos,
|
602 |
+
# max_segments=preprocess_args.max_segments,
|
603 |
+
# , max_videos, max_segments
|
604 |
+
|
605 |
+
tokenizer = get_tokenizer(model_args)
|
606 |
+
|
607 |
+
count_videos = 0
|
608 |
+
count_segments = 0 # TODO
|
609 |
+
|
610 |
+
write_mode = 'w' if preprocess_args.overwrite else 'a'
|
611 |
+
|
612 |
+
get_all = preprocess_args.max_videos is None
|
613 |
+
if get_all:
|
614 |
+
total = len(final_data)
|
615 |
+
else:
|
616 |
+
total = preprocess_args.max_videos
|
617 |
+
|
618 |
+
index = 0
|
619 |
+
data = final_data.items()
|
620 |
+
if preprocess_args.skip_videos is not None:
|
621 |
+
print('Skipping first', preprocess_args.skip_videos, 'videos')
|
622 |
+
data = itertools.islice(data, preprocess_args.skip_videos, None)
|
623 |
+
index = preprocess_args.skip_videos
|
624 |
+
|
625 |
+
if get_all:
|
626 |
+
total = max(0, total - preprocess_args.skip_videos)
|
627 |
+
else:
|
628 |
+
total = min(len(final_data) -
|
629 |
+
preprocess_args.skip_videos, total)
|
630 |
+
|
631 |
+
with open(positive_file, write_mode, encoding='utf-8') as positive, \
|
632 |
+
open(negative_file, write_mode, encoding='utf-8') as negative, \
|
633 |
+
tqdm(total=total) as progress:
|
634 |
+
|
635 |
+
for video_id, sponsor_segments in data:
|
636 |
+
index += 1 # TODO FIX index + incrementing
|
637 |
+
progress.set_description(f'Processing {video_id}')
|
638 |
+
|
639 |
+
if get_all:
|
640 |
+
progress.update()
|
641 |
+
elif count_videos >= preprocess_args.max_videos:
|
642 |
+
break
|
643 |
+
|
644 |
+
words = get_words(video_id, False)
|
645 |
+
if not words:
|
646 |
+
continue
|
647 |
+
|
648 |
+
num_words = len(words)
|
649 |
+
if num_words <= 1:
|
650 |
+
continue
|
651 |
+
|
652 |
+
# TODO only count words that aren't [Music], [Applause], etc.
|
653 |
+
|
654 |
+
segments = segment.generate_labelled_segments(
|
655 |
+
words, tokenizer, segmentation_args, sponsor_segments)
|
656 |
+
|
657 |
+
if not segments:
|
658 |
+
continue
|
659 |
+
|
660 |
+
count_videos += 1
|
661 |
+
if not get_all:
|
662 |
+
progress.update()
|
663 |
+
|
664 |
+
for seg in segments:
|
665 |
+
|
666 |
+
segment_text = ' '.join((x['text'] for x in seg))
|
667 |
+
|
668 |
+
extracted_text = ''
|
669 |
+
for p in extract_sponsors(seg):
|
670 |
+
p_text = ' '.join(p)
|
671 |
+
extracted_text += f'{CustomTokens.START_SPONSOR.value} {p_text} {CustomTokens.END_SPONSOR.value}. '
|
672 |
+
|
673 |
+
duration = segment.word_end(
|
674 |
+
seg[-1]) - segment.word_start(seg[0])
|
675 |
+
wps = len(seg)/duration if duration > 0 else 0
|
676 |
+
# Ignore segments with "not enough words" in the transcript
|
677 |
+
if wps < preprocess_args.min_wps:
|
678 |
+
continue
|
679 |
+
|
680 |
+
d = {
|
681 |
+
'video_index': index,
|
682 |
+
'video_id': video_id,
|
683 |
+
'text': clean_text(segment_text),
|
684 |
+
'words_per_second': wps,
|
685 |
+
}
|
686 |
+
|
687 |
+
d['sponsor'] = bool(extracted_text)
|
688 |
+
d['extracted'] = clean_text(
|
689 |
+
extracted_text) if d['sponsor'] else CustomTokens.NO_SPONSOR.value
|
690 |
+
|
691 |
+
print(json.dumps(d), file=(
|
692 |
+
positive if d['sponsor'] else negative))
|
693 |
+
|
694 |
+
if preprocess_args.do_split:
|
695 |
+
print('Splitting')
|
696 |
+
print('Read files')
|
697 |
+
|
698 |
+
with open(positive_file, encoding='utf-8') as positive:
|
699 |
+
sponsors = positive.readlines()
|
700 |
+
|
701 |
+
with open(negative_file, encoding='utf-8') as negative:
|
702 |
+
non_sponsors = negative.readlines()
|
703 |
+
|
704 |
+
print('Shuffle')
|
705 |
+
random.shuffle(sponsors)
|
706 |
+
random.shuffle(non_sponsors)
|
707 |
+
|
708 |
+
print('Calculate ratios')
|
709 |
+
# Ensure correct ratio of positive to negative segments
|
710 |
+
percentage_negative = 1 - preprocess_args.percentage_positive
|
711 |
+
|
712 |
+
if preprocess_args.percentage_positive * len(sponsors) > len(non_sponsors):
|
713 |
+
# Negative is limiting
|
714 |
+
z = int(preprocess_args.percentage_positive /
|
715 |
+
percentage_negative * len(non_sponsors))
|
716 |
+
|
717 |
+
excess = sponsors[z:]
|
718 |
+
sponsors = sponsors[:z]
|
719 |
+
|
720 |
+
else:
|
721 |
+
# Positive is limiting
|
722 |
+
z = int(percentage_negative /
|
723 |
+
preprocess_args.percentage_positive * len(sponsors))
|
724 |
+
|
725 |
+
excess = non_sponsors[z:]
|
726 |
+
non_sponsors = non_sponsors[:z]
|
727 |
+
|
728 |
+
print('Join')
|
729 |
+
all_labelled_segments = sponsors + non_sponsors
|
730 |
+
|
731 |
+
random.shuffle(all_labelled_segments)
|
732 |
+
|
733 |
+
print('Split')
|
734 |
+
ratios = [preprocess_args.train_split,
|
735 |
+
preprocess_args.test_split,
|
736 |
+
preprocess_args.valid_split]
|
737 |
+
|
738 |
+
train_data, test_data, valid_data = split(
|
739 |
+
all_labelled_segments, ratios)
|
740 |
+
|
741 |
+
splits = {
|
742 |
+
dataset_args.train_file: train_data,
|
743 |
+
dataset_args.test_file: test_data,
|
744 |
+
dataset_args.validation_file: valid_data
|
745 |
+
}
|
746 |
+
|
747 |
+
# Output training, testing and validation data
|
748 |
+
for name, items in splits.items():
|
749 |
+
outfile = os.path.join(dataset_args.data_dir, name)
|
750 |
+
if not os.path.exists(outfile) or preprocess_args.overwrite:
|
751 |
+
with open(outfile, 'w', encoding='utf-8') as fp:
|
752 |
+
fp.writelines(items)
|
753 |
+
else:
|
754 |
+
print('Skipping', name)
|
755 |
+
|
756 |
+
print('Write')
|
757 |
+
# Save excess items
|
758 |
+
excess_path = os.path.join(
|
759 |
+
dataset_args.data_dir, dataset_args.excess_file)
|
760 |
+
if not os.path.exists(excess_path) or preprocess_args.overwrite:
|
761 |
+
with open(excess_path, 'w', encoding='utf-8') as fp:
|
762 |
+
fp.writelines(excess)
|
763 |
+
else:
|
764 |
+
print('Skipping', dataset_args.excess_file)
|
765 |
+
|
766 |
+
print('Finished splitting:', len(sponsors),
|
767 |
+
'sponsors,', len(non_sponsors), 'non sponsors')
|
768 |
+
|
769 |
+
|
770 |
+
def split(arr, ratios):
|
771 |
+
"""Split array according to ratios. Sum of ratios should be less than 1"""
|
772 |
+
|
773 |
+
to_return = []
|
774 |
+
|
775 |
+
cumulative_sum = 0
|
776 |
+
for r in ratios:
|
777 |
+
current = cumulative_sum
|
778 |
+
|
779 |
+
cumulative_sum += r * len(arr)
|
780 |
+
to_return.append(arr[int(current):int(cumulative_sum)])
|
781 |
+
|
782 |
+
return to_return
|
783 |
+
|
784 |
+
|
785 |
+
if __name__ == '__main__':
|
786 |
+
main()
|
src/segment.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import preprocess
|
2 |
+
from shared import CustomTokens
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
|
7 |
+
class SegmentationArguments:
|
8 |
+
pause_threshold: int = field(default=2, metadata={
|
9 |
+
'help': 'When the time between words is greater than pause threshold, force into a new segment'})
|
10 |
+
|
11 |
+
|
12 |
+
# WORDS TO ALWAYS HAVE ON THEIR OWN
|
13 |
+
# always_split_re = re.compile(r'\[\w+\]')
|
14 |
+
# e.g., [Laughter], [Applause], [Music]
|
15 |
+
always_split = [
|
16 |
+
CustomTokens.MUSIC.value,
|
17 |
+
CustomTokens.APPLAUSE.value,
|
18 |
+
CustomTokens.LAUGHTER.value
|
19 |
+
]
|
20 |
+
|
21 |
+
|
22 |
+
def get_overlapping_chunks_of_tokens(tokens, size, overlap):
|
23 |
+
for i in range(0, len(tokens), size-overlap+1):
|
24 |
+
yield tokens[i:i+size]
|
25 |
+
|
26 |
+
|
27 |
+
# Generate up to max_tokens - SAFETY_TOKENS
|
28 |
+
SAFETY_TOKENS = 8
|
29 |
+
|
30 |
+
|
31 |
+
# TODO play around with this?
|
32 |
+
OVERLAP_TOKEN_PERCENTAGE = 0.5 # 0.25
|
33 |
+
|
34 |
+
|
35 |
+
def add_labels_to_words(words, sponsor_segments):
|
36 |
+
|
37 |
+
# TODO binary search
|
38 |
+
for word in words:
|
39 |
+
word['sponsor'] = False
|
40 |
+
for sponsor_segment in sponsor_segments:
|
41 |
+
if sponsor_segment['start'] <= word['start'] <= sponsor_segment['end']:
|
42 |
+
word['sponsor'] = True
|
43 |
+
|
44 |
+
# TODO use extract_segment with mapping function?
|
45 |
+
# TODO remove sponsor segments that contain mostly empty space?
|
46 |
+
|
47 |
+
return words
|
48 |
+
|
49 |
+
|
50 |
+
def generate_labelled_segments(words, tokenizer, segmentation_args, sponsor_segments):
|
51 |
+
segments = generate_segments(words, tokenizer, segmentation_args)
|
52 |
+
|
53 |
+
labelled_segments = list(
|
54 |
+
map(lambda x: add_labels_to_words(x, sponsor_segments), segments))
|
55 |
+
|
56 |
+
return labelled_segments
|
57 |
+
|
58 |
+
|
59 |
+
def word_start(word):
|
60 |
+
return word['start']
|
61 |
+
|
62 |
+
|
63 |
+
def word_end(word):
|
64 |
+
return word.get('end', word['start'])
|
65 |
+
|
66 |
+
|
67 |
+
def generate_segments(words, tokenizer, segmentation_args):
|
68 |
+
first_pass_segments = []
|
69 |
+
|
70 |
+
for index, word in enumerate(words):
|
71 |
+
# Get length of tokenized word
|
72 |
+
cleaned = preprocess.clean_text(word['text'])
|
73 |
+
word['num_tokens'] = len(
|
74 |
+
tokenizer(cleaned, add_special_tokens=False, truncation=True).input_ids)
|
75 |
+
|
76 |
+
add_new_segment = index == 0
|
77 |
+
if not add_new_segment:
|
78 |
+
|
79 |
+
if word['text'] in always_split or words[index-1]['text'] in always_split:
|
80 |
+
add_new_segment = True
|
81 |
+
|
82 |
+
# Pause too small, do not split
|
83 |
+
elif word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
|
84 |
+
add_new_segment = True
|
85 |
+
|
86 |
+
if add_new_segment: # New segment
|
87 |
+
first_pass_segments.append([word])
|
88 |
+
|
89 |
+
else: # Add to current segment
|
90 |
+
first_pass_segments[-1].append(word)
|
91 |
+
|
92 |
+
max_q_size = tokenizer.model_max_length - SAFETY_TOKENS
|
93 |
+
|
94 |
+
buffer_size = OVERLAP_TOKEN_PERCENTAGE*max_q_size # tokenizer.model_max_length
|
95 |
+
|
96 |
+
# In second pass, we split those segments if too big
|
97 |
+
second_pass_segments = []
|
98 |
+
for segment in first_pass_segments:
|
99 |
+
current_segment_num_tokens = 0
|
100 |
+
current_segment = []
|
101 |
+
for word in segment:
|
102 |
+
if current_segment_num_tokens + word['num_tokens'] < max_q_size:
|
103 |
+
# Can add tokens to current segment
|
104 |
+
current_segment.append(word)
|
105 |
+
current_segment_num_tokens += word['num_tokens']
|
106 |
+
else:
|
107 |
+
# Adding this token would make it have too many tokens
|
108 |
+
# We save this batch and create new
|
109 |
+
second_pass_segments.append(current_segment.copy())
|
110 |
+
|
111 |
+
current_segment.append(word)
|
112 |
+
current_segment_num_tokens += word['num_tokens']
|
113 |
+
|
114 |
+
while current_segment_num_tokens > buffer_size and current_segment:
|
115 |
+
first_word = current_segment.pop(0)
|
116 |
+
current_segment_num_tokens -= first_word['num_tokens']
|
117 |
+
|
118 |
+
if current_segment:
|
119 |
+
second_pass_segments.append(current_segment.copy())
|
120 |
+
|
121 |
+
return second_pass_segments
|
122 |
+
|
123 |
+
|
124 |
+
def extract_segment(words, start, end, map_function=None):
|
125 |
+
"""Extract a segment of words that are between (inclusive) the start and end points"""
|
126 |
+
segment_words = []
|
127 |
+
|
128 |
+
if start > end:
|
129 |
+
return segment_words
|
130 |
+
|
131 |
+
# TODO change to binary search
|
132 |
+
for w in words: # Assumes words are sorted
|
133 |
+
if word_end(w) < start:
|
134 |
+
continue # Ignore
|
135 |
+
if word_start(w) > end:
|
136 |
+
break # Done with range
|
137 |
+
if map_function is not None and callable(map_function):
|
138 |
+
w = map_function(w)
|
139 |
+
|
140 |
+
segment_words.append(w)
|
141 |
+
|
142 |
+
return segment_words
|
src/shared.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
from time import time_ns
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from typing import Optional
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from enum import Enum
|
9 |
+
|
10 |
+
|
11 |
+
class CustomTokens(Enum):
|
12 |
+
URL = 'URL_TOKEN'
|
13 |
+
HYPHENATED_URL = 'HYPHENATED_URL_TOKEN'
|
14 |
+
NUMBER_PERCENTAGE = 'NUMBER_PERCENTAGE_TOKEN'
|
15 |
+
NUMBER = 'NUMBER_TOKEN'
|
16 |
+
|
17 |
+
START_SPONSOR = 'START_SPONSOR'
|
18 |
+
END_SPONSOR = 'END_SPONSOR'
|
19 |
+
NO_SPONSOR = 'NO_SPONSOR_FOUND'
|
20 |
+
|
21 |
+
SHORT_HYPHENATED = 'SHORT_HYPHENATED_TOKEN'
|
22 |
+
LONG_WORD = 'LONG_WORD_TOKEN'
|
23 |
+
|
24 |
+
# Custom YouTube tokens
|
25 |
+
MUSIC = '[Music]'
|
26 |
+
APPLAUSE = '[Applause]'
|
27 |
+
LAUGHTER = '[Laughter]'
|
28 |
+
|
29 |
+
PROFANITY_RAW = '[ __ ]' # How YouTube transcribes profanity
|
30 |
+
PROFANITY_CONVERTED = '*****' # Safer version for tokenizing
|
31 |
+
PROFANITY = 'PROFANITY_TOKEN'
|
32 |
+
|
33 |
+
@classmethod
|
34 |
+
def custom_tokens(cls):
|
35 |
+
return [e.value for e in cls]
|
36 |
+
|
37 |
+
@classmethod
|
38 |
+
def add_custom_tokens(cls, tokenizer):
|
39 |
+
tokenizer.add_tokens(cls.custom_tokens())
|
40 |
+
|
41 |
+
|
42 |
+
@dataclass
|
43 |
+
class OutputArguments:
|
44 |
+
|
45 |
+
output_dir: str = field(
|
46 |
+
default='out',
|
47 |
+
metadata={
|
48 |
+
'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
|
49 |
+
},
|
50 |
+
)
|
51 |
+
checkpoint: Optional[str] = field(
|
52 |
+
default=None,
|
53 |
+
metadata={
|
54 |
+
'help': 'Choose the checkpoint/model to train from or test with. Defaults to the latest checkpoint found in `output_dir`.'
|
55 |
+
},
|
56 |
+
)
|
57 |
+
models_dir: str = field(
|
58 |
+
default='models',
|
59 |
+
metadata={
|
60 |
+
'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
|
61 |
+
},
|
62 |
+
)
|
63 |
+
# classifier_dir: str = field(
|
64 |
+
# default='out',
|
65 |
+
# metadata={
|
66 |
+
# 'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
|
67 |
+
# },
|
68 |
+
# )
|
69 |
+
|
70 |
+
|
71 |
+
def seed_factory():
|
72 |
+
return time_ns() % (2**32 - 1)
|
73 |
+
|
74 |
+
|
75 |
+
@dataclass
|
76 |
+
class GeneralArguments:
|
77 |
+
seed: Optional[int] = field(default_factory=seed_factory, metadata={
|
78 |
+
'help': 'Set seed for deterministic training and testing. By default, it uses the current time (results in essentially random results).'
|
79 |
+
})
|
80 |
+
|
81 |
+
def __post_init__(self):
|
82 |
+
random.seed(self.seed)
|
83 |
+
np.random.seed(self.seed)
|
84 |
+
torch.manual_seed(self.seed)
|
85 |
+
torch.cuda.manual_seed_all(self.seed)
|
86 |
+
|
87 |
+
|
88 |
+
def device():
|
89 |
+
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
90 |
+
|
91 |
+
|
92 |
+
def reset():
|
93 |
+
torch.clear_autocast_cache()
|
94 |
+
torch.cuda.empty_cache()
|
95 |
+
gc.collect()
|
96 |
+
print(torch.cuda.memory_summary(device=None, abbreviated=False))
|
src/train.py
ADDED
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from preprocess import load_datasets, DatasetArguments
|
2 |
+
from predict import ClassifierArguments, SPONSOR_MATCH_RE, DEFAULT_TOKEN_PREFIX
|
3 |
+
from shared import device
|
4 |
+
from shared import GeneralArguments, OutputArguments
|
5 |
+
from model import ModelArguments
|
6 |
+
import transformers
|
7 |
+
import logging
|
8 |
+
from model import get_model, get_tokenizer
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
from dataclasses import dataclass, field
|
13 |
+
from typing import Optional
|
14 |
+
import datasets
|
15 |
+
import pickle
|
16 |
+
from transformers import (
|
17 |
+
DataCollatorForSeq2Seq,
|
18 |
+
HfArgumentParser,
|
19 |
+
Seq2SeqTrainer,
|
20 |
+
Seq2SeqTrainingArguments
|
21 |
+
)
|
22 |
+
from transformers.trainer_utils import get_last_checkpoint
|
23 |
+
from transformers.utils import check_min_version
|
24 |
+
from transformers.utils.versions import require_version
|
25 |
+
from sklearn.linear_model import LogisticRegression
|
26 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
27 |
+
|
28 |
+
import re
|
29 |
+
|
30 |
+
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
31 |
+
check_min_version('4.13.0.dev0')
|
32 |
+
require_version('datasets>=1.8.0',
|
33 |
+
'To fix: pip install -r requirements.txt')
|
34 |
+
|
35 |
+
os.environ['WANDB_DISABLED'] = 'true'
|
36 |
+
|
37 |
+
|
38 |
+
logger = logging.getLogger(__name__)
|
39 |
+
|
40 |
+
# Setup logging
|
41 |
+
logging.basicConfig(
|
42 |
+
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
43 |
+
datefmt='%m/%d/%Y %H:%M:%S',
|
44 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
45 |
+
)
|
46 |
+
|
47 |
+
|
48 |
+
@dataclass
|
49 |
+
class DataTrainingArguments:
|
50 |
+
"""
|
51 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
52 |
+
"""
|
53 |
+
|
54 |
+
preprocessing_num_workers: Optional[int] = field(
|
55 |
+
default=None,
|
56 |
+
metadata={'help': 'The number of processes to use for the preprocessing.'},
|
57 |
+
)
|
58 |
+
# https://github.com/huggingface/transformers/issues/5204
|
59 |
+
max_source_length: Optional[int] = field(
|
60 |
+
default=512,
|
61 |
+
metadata={
|
62 |
+
'help': 'The maximum total input sequence length after tokenization. Sequences longer '
|
63 |
+
'than this will be truncated, sequences shorter will be padded.'
|
64 |
+
},
|
65 |
+
)
|
66 |
+
max_target_length: Optional[int] = field(
|
67 |
+
default=512,
|
68 |
+
metadata={
|
69 |
+
'help': 'The maximum total sequence length for target text after tokenization. Sequences longer '
|
70 |
+
'than this will be truncated, sequences shorter will be padded.'
|
71 |
+
},
|
72 |
+
)
|
73 |
+
val_max_target_length: Optional[int] = field(
|
74 |
+
default=None,
|
75 |
+
metadata={
|
76 |
+
'help': 'The maximum total sequence length for validation target text after tokenization. Sequences longer '
|
77 |
+
'than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`.'
|
78 |
+
'This argument is also used to override the ``max_length`` param of ``model.generate``, which is used '
|
79 |
+
'during ``evaluate`` and ``predict``.'
|
80 |
+
},
|
81 |
+
)
|
82 |
+
pad_to_max_length: bool = field(
|
83 |
+
default=False,
|
84 |
+
metadata={
|
85 |
+
'help': 'Whether to pad all samples to model maximum sentence length. '
|
86 |
+
'If False, will pad the samples dynamically when batching to the maximum length in the batch. More '
|
87 |
+
'efficient on GPU but very bad for TPU.'
|
88 |
+
},
|
89 |
+
)
|
90 |
+
max_train_samples: Optional[int] = field(
|
91 |
+
default=None,
|
92 |
+
metadata={
|
93 |
+
'help': 'For debugging purposes or quicker training, truncate the number of training examples to this value if set.'
|
94 |
+
},
|
95 |
+
)
|
96 |
+
max_eval_samples: Optional[int] = field(
|
97 |
+
default=None,
|
98 |
+
metadata={
|
99 |
+
'help': 'For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set.'
|
100 |
+
},
|
101 |
+
)
|
102 |
+
max_predict_samples: Optional[int] = field(
|
103 |
+
default=None,
|
104 |
+
metadata={
|
105 |
+
'help': 'For debugging purposes or quicker training, truncate the number of prediction examples to this value if set.'
|
106 |
+
},
|
107 |
+
)
|
108 |
+
num_beams: Optional[int] = field(
|
109 |
+
default=None,
|
110 |
+
metadata={
|
111 |
+
'help': 'Number of beams to use for evaluation. This argument will be passed to ``model.generate``, '
|
112 |
+
'which is used during ``evaluate`` and ``predict``.'
|
113 |
+
},
|
114 |
+
)
|
115 |
+
ignore_pad_token_for_loss: bool = field(
|
116 |
+
default=True,
|
117 |
+
metadata={
|
118 |
+
'help': 'Whether to ignore the tokens corresponding to padded labels in the loss computation or not.'
|
119 |
+
},
|
120 |
+
)
|
121 |
+
source_prefix: Optional[str] = field(
|
122 |
+
default=DEFAULT_TOKEN_PREFIX, metadata={
|
123 |
+
'help': 'A prefix to add before every source text (useful for T5 models).'}
|
124 |
+
)
|
125 |
+
|
126 |
+
# TODO add vectorizer params
|
127 |
+
|
128 |
+
def __post_init__(self):
|
129 |
+
if self.val_max_target_length is None:
|
130 |
+
self.val_max_target_length = self.max_target_length
|
131 |
+
|
132 |
+
|
133 |
+
@dataclass
|
134 |
+
class SequenceTrainingArguments(OutputArguments, Seq2SeqTrainingArguments):
|
135 |
+
seed: Optional[int] = GeneralArguments.__dataclass_fields__['seed']
|
136 |
+
|
137 |
+
num_train_epochs: float = field(
|
138 |
+
default=1, metadata={'help': 'Total number of training epochs to perform.'})
|
139 |
+
|
140 |
+
save_steps: int = field(default=2500, metadata={
|
141 |
+
'help': 'Save checkpoint every X updates steps.'})
|
142 |
+
eval_steps: int = field(default=2500, metadata={
|
143 |
+
'help': 'Run an evaluation every X steps.'})
|
144 |
+
logging_steps: int = field(default=2500, metadata={
|
145 |
+
'help': 'Log every X updates steps.'})
|
146 |
+
|
147 |
+
skip_train_transformer: bool = field(default=False, metadata={
|
148 |
+
'help': 'Whether to skip training the transformer.'})
|
149 |
+
train_classifier: bool = field(default=False, metadata={
|
150 |
+
'help': 'Whether to run training on the 2nd phase (classifier).'})
|
151 |
+
|
152 |
+
# do_eval: bool = field(default=False, metadata={
|
153 |
+
# 'help': 'Whether to run eval on the dev set.'})
|
154 |
+
do_predict: bool = field(default=False, metadata={
|
155 |
+
'help': 'Whether to run predictions on the test set.'})
|
156 |
+
|
157 |
+
per_device_train_batch_size: int = field(
|
158 |
+
default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for training.'}
|
159 |
+
)
|
160 |
+
per_device_eval_batch_size: int = field(
|
161 |
+
default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for evaluation.'}
|
162 |
+
)
|
163 |
+
|
164 |
+
# report_to: Optional[List[str]] = field(
|
165 |
+
# default=None, metadata={"help": "The list of integrations to report the results and logs to."}
|
166 |
+
# )
|
167 |
+
evaluation_strategy: str = field(
|
168 |
+
default='steps',
|
169 |
+
metadata={
|
170 |
+
'help': 'The evaluation strategy to use.',
|
171 |
+
'choices': ['no', 'steps', 'epoch']
|
172 |
+
},
|
173 |
+
)
|
174 |
+
|
175 |
+
# evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
|
176 |
+
# The evaluation strategy to adopt during training. Possible values are:
|
177 |
+
|
178 |
+
# * :obj:`"no"`: No evaluation is done during training.
|
179 |
+
# * :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
|
180 |
+
# * :obj:`"epoch"`: Evaluation is done at the end of each epoch.
|
181 |
+
|
182 |
+
|
183 |
+
def main():
|
184 |
+
# reset()
|
185 |
+
|
186 |
+
# See all possible arguments in src/transformers/training_args.py
|
187 |
+
# or by passing the --help flag to this script.
|
188 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
189 |
+
|
190 |
+
hf_parser = HfArgumentParser((
|
191 |
+
ModelArguments,
|
192 |
+
DatasetArguments,
|
193 |
+
DataTrainingArguments,
|
194 |
+
SequenceTrainingArguments,
|
195 |
+
ClassifierArguments
|
196 |
+
))
|
197 |
+
model_args, dataset_args, data_training_args, training_args, classifier_args = hf_parser.parse_args_into_dataclasses()
|
198 |
+
|
199 |
+
log_level = training_args.get_process_log_level()
|
200 |
+
logger.setLevel(log_level)
|
201 |
+
datasets.utils.logging.set_verbosity(log_level)
|
202 |
+
transformers.utils.logging.set_verbosity(log_level)
|
203 |
+
transformers.utils.logging.enable_default_handler()
|
204 |
+
transformers.utils.logging.enable_explicit_format()
|
205 |
+
|
206 |
+
# Set seed before initializing model.
|
207 |
+
# set_seed(training_args.seed)
|
208 |
+
|
209 |
+
# Log on each process the small summary:
|
210 |
+
logger.warning(
|
211 |
+
f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}'
|
212 |
+
+ f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
|
213 |
+
)
|
214 |
+
logger.info(f'Training/evaluation parameters {training_args}')
|
215 |
+
|
216 |
+
# FP16 https://github.com/huggingface/transformers/issues/9295
|
217 |
+
|
218 |
+
# Works:
|
219 |
+
# https://huggingface.co/docs/transformers/model_doc/t5v1.1
|
220 |
+
# google/t5-v1_1-small
|
221 |
+
# google/t5-v1_1-base
|
222 |
+
# google/t5-v1_1-large
|
223 |
+
# google/t5-v1_1-xl
|
224 |
+
# google/t5-v1_1-xxl
|
225 |
+
|
226 |
+
# https://huggingface.co/docs/transformers/model_doc/t5
|
227 |
+
# t5-small
|
228 |
+
# t5-base
|
229 |
+
# t5-large
|
230 |
+
# t5-3b
|
231 |
+
# t5-11b
|
232 |
+
|
233 |
+
# allenai/led-base-16384 - https://github.com/huggingface/transformers/issues/9810
|
234 |
+
|
235 |
+
# Further work:
|
236 |
+
# Multilingual- https://huggingface.co/docs/transformers/model_doc/mt5
|
237 |
+
|
238 |
+
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
239 |
+
# download the dataset.
|
240 |
+
if training_args.skip_train_transformer and not training_args.train_classifier:
|
241 |
+
print('Nothing to do. Exiting')
|
242 |
+
return
|
243 |
+
|
244 |
+
raw_datasets = load_datasets(dataset_args)
|
245 |
+
# , cache_dir=model_args.cache_dir
|
246 |
+
|
247 |
+
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
248 |
+
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
249 |
+
|
250 |
+
if training_args.train_classifier:
|
251 |
+
print('Train classifier')
|
252 |
+
# 1. Vectorize raw data to pass into classifier
|
253 |
+
# CountVectorizer TfidfVectorizer
|
254 |
+
# TfidfVectorizer - better (comb of CountVectorizer)
|
255 |
+
vectorizer = TfidfVectorizer( # CountVectorizer
|
256 |
+
# lowercase=False,
|
257 |
+
# stop_words='english', # TODO optimise stop words?
|
258 |
+
# stop_words=stop_words,
|
259 |
+
|
260 |
+
ngram_range=(1, 2), # best so far
|
261 |
+
# max_features=8000 # remove for higher accuracy?
|
262 |
+
max_features=50000
|
263 |
+
# max_features=10000
|
264 |
+
)
|
265 |
+
|
266 |
+
train_test_data = {
|
267 |
+
'train': {
|
268 |
+
'X': [],
|
269 |
+
'y': []
|
270 |
+
},
|
271 |
+
'test': {
|
272 |
+
'X': [],
|
273 |
+
'y': []
|
274 |
+
}
|
275 |
+
}
|
276 |
+
|
277 |
+
print('Splitting')
|
278 |
+
for ds_type in train_test_data:
|
279 |
+
dataset = raw_datasets[ds_type]
|
280 |
+
|
281 |
+
for row in dataset:
|
282 |
+
|
283 |
+
# Get matches:
|
284 |
+
if row['sponsor']:
|
285 |
+
matches = re.findall(SPONSOR_MATCH_RE, row['extracted'])
|
286 |
+
else:
|
287 |
+
matches = [row['text']]
|
288 |
+
|
289 |
+
for match in matches:
|
290 |
+
train_test_data[ds_type]['X'].append(match)
|
291 |
+
train_test_data[ds_type]['y'].append(row['sponsor'])
|
292 |
+
|
293 |
+
print('Fitting')
|
294 |
+
_X_train = vectorizer.fit_transform(train_test_data['train']['X'])
|
295 |
+
_X_test = vectorizer.transform(train_test_data['test']['X'])
|
296 |
+
|
297 |
+
y_train = train_test_data['train']['y']
|
298 |
+
y_test = train_test_data['test']['y']
|
299 |
+
|
300 |
+
# 2. Create classifier
|
301 |
+
classifier = LogisticRegression(max_iter=500)
|
302 |
+
|
303 |
+
# 3. Fit data
|
304 |
+
print('fit classifier')
|
305 |
+
classifier.fit(_X_train, y_train)
|
306 |
+
|
307 |
+
# 4. Measure accuracy
|
308 |
+
accuracy = classifier.score(_X_test, y_test)
|
309 |
+
|
310 |
+
print(f'[LogisticRegression] Accuracy percent:',
|
311 |
+
round(accuracy*100, 3))
|
312 |
+
|
313 |
+
# 5. Save classifier and vectorizer
|
314 |
+
with open(os.path.join(classifier_args.classifier_dir, classifier_args.classifier_file), 'wb') as fp:
|
315 |
+
pickle.dump(classifier, fp)
|
316 |
+
|
317 |
+
with open(os.path.join(classifier_args.classifier_dir, classifier_args.vectorizer_file), 'wb') as fp:
|
318 |
+
pickle.dump(vectorizer, fp)
|
319 |
+
|
320 |
+
if not training_args.skip_train_transformer:
|
321 |
+
|
322 |
+
if data_training_args.source_prefix is None and 't5-' in model_args.model_name_or_path:
|
323 |
+
logger.warning(
|
324 |
+
"You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with `--source_prefix 'summarize: ' `"
|
325 |
+
)
|
326 |
+
|
327 |
+
# Detecting last checkpoint.
|
328 |
+
last_checkpoint = None
|
329 |
+
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
|
330 |
+
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
331 |
+
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
332 |
+
raise ValueError(
|
333 |
+
f'Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome.'
|
334 |
+
)
|
335 |
+
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
336 |
+
logger.info(
|
337 |
+
f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
|
338 |
+
)
|
339 |
+
|
340 |
+
# Load pretrained model and tokenizer
|
341 |
+
tokenizer = get_tokenizer(model_args)
|
342 |
+
model = get_model(model_args)
|
343 |
+
model.to(device())
|
344 |
+
model.resize_token_embeddings(len(tokenizer))
|
345 |
+
|
346 |
+
if model.config.decoder_start_token_id is None:
|
347 |
+
raise ValueError(
|
348 |
+
'Make sure that `config.decoder_start_token_id` is correctly defined')
|
349 |
+
|
350 |
+
if hasattr(model.config, 'max_position_embeddings') and model.config.max_position_embeddings < data_training_args.max_source_length:
|
351 |
+
if model_args.resize_position_embeddings is None:
|
352 |
+
logger.warning(
|
353 |
+
f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} to {data_training_args.max_source_length}."
|
354 |
+
)
|
355 |
+
model.resize_position_embeddings(
|
356 |
+
data_training_args.max_source_length)
|
357 |
+
|
358 |
+
elif model_args.resize_position_embeddings:
|
359 |
+
model.resize_position_embeddings(
|
360 |
+
data_training_args.max_source_length)
|
361 |
+
|
362 |
+
else:
|
363 |
+
raise ValueError(
|
364 |
+
f'`--max_source_length` is set to {data_training_args.max_source_length}, but the model only has {model.config.max_position_embeddings}'
|
365 |
+
f' position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically '
|
366 |
+
"resize the model's position encodings by passing `--resize_position_embeddings`."
|
367 |
+
)
|
368 |
+
|
369 |
+
# Preprocessing the datasets.
|
370 |
+
# We need to tokenize inputs and targets.
|
371 |
+
column_names = raw_datasets['train'].column_names
|
372 |
+
|
373 |
+
# Temporarily set max_target_length for training.
|
374 |
+
max_target_length = data_training_args.max_target_length
|
375 |
+
padding = 'max_length' if data_training_args.pad_to_max_length else False
|
376 |
+
|
377 |
+
if training_args.label_smoothing_factor > 0 and not hasattr(model, 'prepare_decoder_input_ids_from_labels'):
|
378 |
+
logger.warning(
|
379 |
+
'label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for'
|
380 |
+
f'`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory'
|
381 |
+
)
|
382 |
+
|
383 |
+
prefix = data_training_args.source_prefix if data_training_args.source_prefix is not None else ''
|
384 |
+
|
385 |
+
# https://github.com/huggingface/transformers/issues/5204
|
386 |
+
def preprocess_function(examples):
|
387 |
+
inputs = examples['text']
|
388 |
+
targets = examples['extracted']
|
389 |
+
inputs = [prefix + inp for inp in inputs]
|
390 |
+
model_inputs = tokenizer(
|
391 |
+
inputs, max_length=data_training_args.max_source_length, padding=padding, truncation=True)
|
392 |
+
|
393 |
+
# Setup the tokenizer for targets
|
394 |
+
with tokenizer.as_target_tokenizer():
|
395 |
+
labels = tokenizer(
|
396 |
+
targets, max_length=max_target_length, padding=padding, truncation=True)
|
397 |
+
|
398 |
+
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
399 |
+
# padding in the loss.
|
400 |
+
if padding == 'max_length' and data_training_args.ignore_pad_token_for_loss:
|
401 |
+
labels['input_ids'] = [
|
402 |
+
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels['input_ids']
|
403 |
+
]
|
404 |
+
model_inputs['labels'] = labels['input_ids']
|
405 |
+
|
406 |
+
return model_inputs
|
407 |
+
|
408 |
+
def prepare_dataset(dataset, desc):
|
409 |
+
return dataset.map(
|
410 |
+
preprocess_function,
|
411 |
+
batched=True,
|
412 |
+
num_proc=data_training_args.preprocessing_num_workers,
|
413 |
+
remove_columns=column_names,
|
414 |
+
load_from_cache_file=not dataset_args.overwrite_cache,
|
415 |
+
desc=desc, # tokenizing train dataset
|
416 |
+
)
|
417 |
+
# train_dataset # TODO shuffle?
|
418 |
+
|
419 |
+
# if training_args.do_train:
|
420 |
+
if 'train' not in raw_datasets: # TODO do checks above?
|
421 |
+
raise ValueError('Train dataset missing')
|
422 |
+
train_dataset = raw_datasets['train']
|
423 |
+
if data_training_args.max_train_samples is not None:
|
424 |
+
train_dataset = train_dataset.select(
|
425 |
+
range(data_training_args.max_train_samples))
|
426 |
+
with training_args.main_process_first(desc='train dataset map pre-processing'):
|
427 |
+
train_dataset = prepare_dataset(
|
428 |
+
train_dataset, desc='Running tokenizer on train dataset')
|
429 |
+
|
430 |
+
max_target_length = data_training_args.val_max_target_length
|
431 |
+
if 'validation' not in raw_datasets:
|
432 |
+
raise ValueError('Validation dataset missing')
|
433 |
+
eval_dataset = raw_datasets['validation']
|
434 |
+
if data_training_args.max_eval_samples is not None:
|
435 |
+
eval_dataset = eval_dataset.select(
|
436 |
+
range(data_training_args.max_eval_samples))
|
437 |
+
with training_args.main_process_first(desc='validation dataset map pre-processing'):
|
438 |
+
eval_dataset = prepare_dataset(
|
439 |
+
eval_dataset, desc='Running tokenizer on validation dataset')
|
440 |
+
|
441 |
+
if 'test' not in raw_datasets:
|
442 |
+
raise ValueError('Test dataset missing')
|
443 |
+
predict_dataset = raw_datasets['test']
|
444 |
+
if data_training_args.max_predict_samples is not None:
|
445 |
+
predict_dataset = predict_dataset.select(
|
446 |
+
range(data_training_args.max_predict_samples))
|
447 |
+
with training_args.main_process_first(desc='prediction dataset map pre-processing'):
|
448 |
+
predict_dataset = prepare_dataset(
|
449 |
+
predict_dataset, desc='Running tokenizer on prediction dataset')
|
450 |
+
|
451 |
+
# Data collator
|
452 |
+
label_pad_token_id = - \
|
453 |
+
100 if data_training_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
454 |
+
data_collator = DataCollatorForSeq2Seq(
|
455 |
+
tokenizer,
|
456 |
+
model=model,
|
457 |
+
label_pad_token_id=label_pad_token_id,
|
458 |
+
pad_to_multiple_of=8 if training_args.fp16 else None,
|
459 |
+
)
|
460 |
+
|
461 |
+
# Done processing datasets
|
462 |
+
|
463 |
+
# Initialize our Trainer
|
464 |
+
trainer = Seq2SeqTrainer(
|
465 |
+
model=model,
|
466 |
+
args=training_args,
|
467 |
+
train_dataset=train_dataset,
|
468 |
+
eval_dataset=eval_dataset,
|
469 |
+
tokenizer=tokenizer,
|
470 |
+
data_collator=data_collator,
|
471 |
+
)
|
472 |
+
|
473 |
+
# Training
|
474 |
+
checkpoint = None
|
475 |
+
if training_args.resume_from_checkpoint is not None:
|
476 |
+
checkpoint = training_args.resume_from_checkpoint
|
477 |
+
elif last_checkpoint is not None:
|
478 |
+
checkpoint = last_checkpoint
|
479 |
+
|
480 |
+
try:
|
481 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
482 |
+
trainer.save_model() # Saves the tokenizer too for easy upload
|
483 |
+
except KeyboardInterrupt:
|
484 |
+
print('Saving model')
|
485 |
+
trainer.save_model(os.path.join(
|
486 |
+
training_args.output_dir, 'checkpoint-latest')) # TODO use dir
|
487 |
+
raise
|
488 |
+
|
489 |
+
metrics = train_result.metrics
|
490 |
+
max_train_samples = data_training_args.max_train_samples or len(
|
491 |
+
train_dataset)
|
492 |
+
metrics['train_samples'] = min(max_train_samples, len(train_dataset))
|
493 |
+
|
494 |
+
trainer.log_metrics('train', metrics)
|
495 |
+
trainer.save_metrics('train', metrics)
|
496 |
+
trainer.save_state()
|
497 |
+
|
498 |
+
kwargs = {'finetuned_from': model_args.model_name_or_path,
|
499 |
+
'tasks': 'summarization'}
|
500 |
+
|
501 |
+
if training_args.push_to_hub:
|
502 |
+
trainer.push_to_hub(**kwargs)
|
503 |
+
else:
|
504 |
+
trainer.create_model_card(**kwargs)
|
505 |
+
|
506 |
+
|
507 |
+
if __name__ == '__main__':
|
508 |
+
main()
|
src/utils.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import os
|
3 |
+
|
4 |
+
class Job:
|
5 |
+
def __init__(self, function, *args, **kwargs) -> None:
|
6 |
+
self.function = function
|
7 |
+
self.args = args
|
8 |
+
self.kwargs = kwargs
|
9 |
+
|
10 |
+
self.result = None
|
11 |
+
|
12 |
+
|
13 |
+
class InterruptibleThreadPool:
|
14 |
+
def __init__(self,
|
15 |
+
num_workers=None,
|
16 |
+
loop=None,
|
17 |
+
shutdown_message='\nAttempting graceful shutdown, press Ctrl+C again to exit...',
|
18 |
+
on_job_complete=None, # Useful for monitoring progress
|
19 |
+
raise_after_interrupt=False,
|
20 |
+
) -> None:
|
21 |
+
self.num_workers = os.cpu_count() if num_workers is None else num_workers
|
22 |
+
self.loop = asyncio.get_event_loop() if loop is None else loop
|
23 |
+
self.shutdown_message = shutdown_message
|
24 |
+
|
25 |
+
self.sem = asyncio.Semaphore(num_workers)
|
26 |
+
|
27 |
+
self.jobs = []
|
28 |
+
|
29 |
+
self.on_job_complete = on_job_complete
|
30 |
+
self.raise_after_interrupt = raise_after_interrupt
|
31 |
+
|
32 |
+
async def _sync_to_async(self, job):
|
33 |
+
async with self.sem: # Limit number of parallel tasks
|
34 |
+
job.result = await self.loop.run_in_executor(None, job.function, *job.args, **job.kwargs)
|
35 |
+
|
36 |
+
if callable(self.on_job_complete):
|
37 |
+
self.on_job_complete(job)
|
38 |
+
|
39 |
+
return job
|
40 |
+
|
41 |
+
def add_job(self, job):
|
42 |
+
self.jobs.append(job)
|
43 |
+
|
44 |
+
def run(self):
|
45 |
+
try:
|
46 |
+
tasks = [
|
47 |
+
# creating task starts coroutine
|
48 |
+
asyncio.ensure_future(self._sync_to_async(job))
|
49 |
+
for job in self.jobs
|
50 |
+
]
|
51 |
+
|
52 |
+
# https://stackoverflow.com/a/42097478
|
53 |
+
self.loop.run_until_complete(
|
54 |
+
asyncio.gather(*tasks, return_exceptions=True)
|
55 |
+
)
|
56 |
+
|
57 |
+
except KeyboardInterrupt:
|
58 |
+
# Optionally show a message if the shutdown may take a while
|
59 |
+
print(self.shutdown_message, flush=True)
|
60 |
+
|
61 |
+
# Do not show `asyncio.CancelledError` exceptions during shutdown
|
62 |
+
# (a lot of these may be generated, skip this if you prefer to see them)
|
63 |
+
def shutdown_exception_handler(loop, context):
|
64 |
+
if "exception" not in context \
|
65 |
+
or not isinstance(context["exception"], asyncio.CancelledError):
|
66 |
+
loop.default_exception_handler(context)
|
67 |
+
self.loop.set_exception_handler(shutdown_exception_handler)
|
68 |
+
|
69 |
+
# Handle shutdown gracefully by waiting for all tasks to be cancelled
|
70 |
+
cancelled_tasks = asyncio.gather(
|
71 |
+
*asyncio.all_tasks(loop=self.loop), loop=self.loop, return_exceptions=True)
|
72 |
+
cancelled_tasks.add_done_callback(lambda t: self.loop.stop())
|
73 |
+
cancelled_tasks.cancel()
|
74 |
+
|
75 |
+
# Keep the event loop running until it is either destroyed or all
|
76 |
+
# tasks have really terminated
|
77 |
+
while not cancelled_tasks.done() and not self.loop.is_closed():
|
78 |
+
self.loop.run_forever()
|
79 |
+
|
80 |
+
if self.raise_after_interrupt:
|
81 |
+
raise
|
82 |
+
finally:
|
83 |
+
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
|
84 |
+
self.loop.close()
|
85 |
+
|
86 |
+
return self.jobs
|