Commit b27b0d5 by Joshua Lochner
Parent(s): bce5ce9

Improve preprocessing and segmentation

Files changed:
- src/preprocess.py (+140 -111)
- src/segment.py (+44 -56)
src/preprocess.py
CHANGED
@@ -1,3 +1,4 @@
+from shared import CATGEGORY_OPTIONS
 from utils import jaccard
 from shared import START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE
 from functools import lru_cache

@@ -113,19 +114,26 @@ def list_transcripts(video_id):
     return YouTubeTranscriptApi.list_transcripts(video_id)


+WORDS_TO_REMOVE = [
+    CustomTokens.MUSIC.value,
+    CustomTokens.APPLAUSE.value,
+    CustomTokens.LAUGHTER.value
+]
+
+
 @lru_cache(maxsize=16)
-def get_words(video_id, process=True, fallback=True, transcript_type='auto'):
+def get_words(video_id, process=True, transcript_type='auto', fallback='manual', filter_words_to_remove=True):
     """Get parsed video transcript with caching system
     returns None if not processed yet and process is False
     """
-    get_manual_if_fail = fallback and transcript_type == 'auto'
     transcript_path = os.path.join(  # TODO use relative path to this
         'transcripts', transcript_type, f'{video_id}.json')
+
+    words = None
     try:
         if os.path.exists(transcript_path):  # Load from file
             with open(transcript_path) as fp:
-                words = json.load(fp)
+                words = json.load(fp)  # May be empty

         elif process:
             transcript_list = list_transcripts(video_id)

@@ -138,52 +146,55 @@ def get_words(video_id, process=True, fallback=True, transcript_type='auto'):
     except (TooManyRequests, YouTubeRequestFailed, requests.exceptions.ConnectionError) as e:  # Can retry
         print(e)
         time.sleep(10)  # Timeout
-        return get_words(video_id, process,
+        return get_words(video_id, process, transcript_type, fallback)

     except CouldNotRetrieveTranscript:
         pass
+
     except json.decoder.JSONDecodeError:
         print('JSONDecodeError for', video_id)
         os.remove(transcript_path)  # Remove file and try again
-        return get_words(video_id, process,
-        json.dump(words, fp)
-    if
+        return get_words(video_id, process, transcript_type, fallback)
+
+    # Tried to process it, but it was empty...
+    if process and not os.path.exists(transcript_path):
+        with open(transcript_path, 'w') as fp:
+            json.dump(words, fp)

+    if not words and fallback is not None:
+        return get_words(video_id, process, transcript_type=fallback, fallback=None)

+    if words and filter_words_to_remove:
+        words = list(filter(lambda x: x['text'] not in WORDS_TO_REMOVE, words))

     return words


 # TODO make min_sponsor_segment_length param
+# TODO rename to extract_segments
 def extract_sponsors(words, min_sponsor_segment_length=3):
-    if not words:
+    if not words or len(words) < min_sponsor_segment_length:
         return []

     paragraphs = []
     current = []
     prev_category = None

-    i
-        unimportant = i == len(words) or words[i]['category'] is None
-        if unimportant or words[i]
+    for i in range(len(words) + 1):
+        unimportant = i == len(words) or words[i].get('category') is None
+
+        if unimportant or words[i].get('category') != prev_category:
             if current:  # Save the current batch
                 paragraphs.append({
                     'words': current,
-                    'category': current[-1]
+                    'category': current[-1].get('category'),
                 })

                 current = []

         if not unimportant:  # Some useful information to save
             current.append(words[i])
-            prev_category = words[i]
-        i += 1
+            prev_category = words[i].get('category')

     # Remove all too short:
     return list(filter(lambda x: len(x['words']) >= min_sponsor_segment_length, paragraphs))
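As a quick illustration of the reworked extract_sponsors above (not part of the commit): a hypothetical usage sketch. The word dicts are made up, and src/ is assumed to be on the import path; words obtain their 'category' key from segment.add_labels_to_words.

    # Hypothetical usage sketch (not part of the commit)
    from preprocess import extract_sponsors

    words = [
        {'text': 'hey', 'category': None},
        {'text': 'this', 'category': 'sponsor'},
        {'text': 'video', 'category': 'sponsor'},
        {'text': 'is', 'category': 'sponsor'},
        {'text': 'sponsored', 'category': 'sponsor'},
        {'text': 'enjoy', 'category': None},
    ]

    # Consecutive words sharing a category are grouped into one paragraph;
    # groups with fewer than min_sponsor_segment_length words are dropped.
    print(extract_sponsors(words, min_sponsor_segment_length=3))
    # -> one segment containing the four 'sponsor' words, with category 'sponsor'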
@@ -277,24 +288,20 @@ class PreprocessArguments:
     min_views: int = field(
         default=5, metadata={'help': 'Minimum number of views a segment must have to be considered. 0 = show all'})

+    # min_reputation: int = field(
+    #     default=0, metadata={'help': 'Minimum reputation a user must have for the segment to be included'})
+
     min_date: str = field(
-        # release of v2.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/2.0)
+        # default='08/06/2020',  # release of v2.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/2.0)
+        # release of v3.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/3.0)
+        default='20/08/2021',
         # default='01/10/2020',  # No more autovote
-        metadata={'help': 'Only use submissions from after this date'})
-
-        metadata={
-            'nargs': '+',
-            'choices': ['intro', 'sponsor', 'interaction']
-            # 'outro', 'selfpromo', 'preview',
-            # 'poi_highlight', 'filler', 'music_offtopic',
-            # 'moreCategories'
-        }
-    )
+        metadata={'help': 'Only use submissions from after this date (inclusive)'})
+
+    max_date: str = field(
+        # default='01/01/9999',  # Include all
+        default='27/01/2022',
+        metadata={'help': 'Only use videos that have some segment from before this date (exclusive). This allows for videos to have segments be corrected, but ignores new videos (posted after this date) to enter the pool.'})

     do_transcribe: bool = field(
         default=False, metadata={'help': 'Get transcripts for videos'}

@@ -302,9 +309,9 @@ class PreprocessArguments:
     num_jobs: int = field(
         default=4, metadata={'help': 'Number of transcripts to download in parallel'})

+    overwrite: bool = field(
+        default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
+    )

     do_generate: bool = field(
         default=False, metadata={'help': 'Generate labelled data.'}

@@ -381,22 +388,6 @@ def download_file(url, filename):
     return total_bytes == os.path.getsize(filename)


-@dataclass
-class ProcessedArguments:
-    processed_dir: Optional[str] = field(
-        default='processed',
-        metadata={
-            'help': 'Processed data directory'
-        },
-    )
-    processed_file: Optional[str] = field(
-        default='final.json',
-        metadata={
-            'help': 'Processed data file'
-        },
-    )
-
-
 def load_datasets(dataset_args):
     print('Reading datasets')
     data_files = {}

@@ -411,7 +402,7 @@ def load_datasets(dataset_args):
     data_files['test'] = os.path.join(
         dataset_args.data_dir, dataset_args.test_file)

-    return load_dataset('json', data_files=data_files)
+    return load_dataset('json', data_files=data_files, cache_dir=dataset_args.dataset_cache_dir)


 @dataclass

@@ -422,6 +413,18 @@ class DatasetArguments:
             'help': 'The directory which stores train, test and/or validation data.'
         },
     )
+    processed_file: Optional[str] = field(
+        default='segments.json',
+        metadata={
+            'help': 'Processed data file'
+        },
+    )
+    processed_database: Optional[str] = field(
+        default='processed_database.json',
+        metadata={
+            'help': 'Processed database file'
+        },
+    )

     train_file: Optional[str] = field(
         default='train.json', metadata={'help': 'The input training data file (a jsonlines file).'}

@@ -444,7 +447,12 @@ class DatasetArguments:
             'help': 'The excess segments left after the split'
         },
     )
+    dataset_cache_dir: Optional[str] = field(
+        default=None,
+        metadata={
+            'help': 'Where to store the cached datasets'
+        },
+    )
     overwrite_cache: bool = field(
         default=False, metadata={'help': 'Overwrite the cached training and evaluation sets'}
     )

@@ -472,13 +480,12 @@ def main():
     # Generate final.json from sponsorTimes.csv
     hf_parser = HfArgumentParser((
         PreprocessArguments,
-        ProcessedArguments,
         DatasetArguments,
         segment.SegmentationArguments,
        ModelArguments,
        GeneralArguments
    ))
-    preprocess_args,
+    preprocess_args, dataset_args, segmentation_args, model_args, _ = hf_parser.parse_args_into_dataclasses()

    raw_dataset_path = os.path.join(
        preprocess_args.raw_data_dir, preprocess_args.raw_data_file)

@@ -491,28 +498,28 @@ def main():
            break
        print('Failed, trying next')

-    print('Parsing raw database')
-    db = {}
+    processed_db_path = os.path.join(
+        dataset_args.data_dir, dataset_args.processed_database)
+
+    def read_db():
+        if not preprocess_args.overwrite and os.path.exists(processed_db_path):
+            with open(processed_db_path) as fp:
+                return json.load(fp)
+        print('Processing raw database')
+        db = {}
+
+        allowed_categories = list(map(str.lower, CATGEGORY_OPTIONS))
        with open(raw_dataset_path, newline='') as csvfile:
            reader = csv.DictReader(csvfile)

            for line in reader:
-                submission_time = float(line['timeSubmitted'])/1e3
-
-                if datetime.fromtimestamp(submission_time) < latest_time:
-                    continue

                if line['service'] != 'YouTube':
                    continue
                if len(line['videoID']) != 11:
                    continue  # Invalid youtube video ID

-                if line['category'] not in
+                if line['category'] not in allowed_categories:
                    continue
                if line['actionType'] != 'skip':
                    continue

@@ -522,17 +529,18 @@ def main():
                    continue

                # Skip those that aren't highly voted
-                if
+                votes = int(line['votes'])
+                if votes < preprocess_args.min_votes:
                    continue

                locked = line['locked'] == '1'

-                #
+                reputation = float(line['reputation'])
+                # if reputation < preprocess_args.min_reputation:
+                #     continue  # TODO add back?
+                # Problems like mGVn1wCkBrE
+
+                # TODO ignore if over max_duration

                if line['videoID'] not in db:
                    db[line['videoID']] = []

@@ -541,15 +549,37 @@ def main():
                    'uuid': line['UUID'],
                    'start': float(line['startTime']),
                    'end': float(line['endTime']),
-                    'votes':
+                    'votes': votes,
                    'locked': locked,
-                    'views': line['views'],
-                    'submission_time':
-                    'reputation':
+                    'views': int(line['views']),
+                    'submission_time': float(line['timeSubmitted'])/1e3,
+                    'reputation': reputation,
                    'category': line['category'],
-                    'action': line['actionType'],
+                    # 'action': line['actionType'],
                })

+        # We now remove whole videos from the list
+        # Helps with obtaining "fully-labelled" videos
+        min_date = datetime.strptime(preprocess_args.min_date, '%d/%m/%Y')
+        max_date = datetime.strptime(preprocess_args.max_date, '%d/%m/%Y')
+        for key in list(db):
+            if any(datetime.fromtimestamp(x['submission_time']) < min_date for x in db[key]):
+                # Remove videos where any of its segments were submitted before min_date
+                # (essentially removes videos uploaded before min_date)
+                # Prevents issues where some segments of a video are excluded
+                del db[key]
+            elif all(datetime.fromtimestamp(x['submission_time']) > max_date for x in db[key]):
+                # Remove videos where all of its segments were submitted after max_date
+                # (essentially removes videos uploaded after max_date)
+                # Allows for segments to be corrected for past videos
+                del db[key]
+            elif any(not x['locked'] and x['views'] < preprocess_args.min_views for x in db[key]):
+                # Remove videos where any of its non-locked segments do not have enough views
+                # (essentially skips videos that have not been fully watched/reviewed)
+                # Always include segments locked by VIPs, regardless of view count
+                del db[key]
+
        num_segments = 0

        # Remove duplicate sponsor segments by choosing best (most votes)

@@ -559,20 +589,21 @@ def main():
            num_segments += len(db[key])
        print('Saved', len(db), 'videos and', num_segments, 'segments')

+        with open(processed_db_path, 'w') as fp:
+            json.dump(db, fp)
+
        return db

    # 'videoID', 'startTime', 'endTime', 'votes', 'locked', 'incorrectVotes', 'UUID',
    # 'userID', 'timeSubmitted', 'views', 'category', 'actionType', 'service', 'videoDuration',
    # 'hidden', 'reputation', 'shadowHidden', 'hashedVideoID', 'userAgent', 'description'
-    parsed_database = None
    if preprocess_args.do_transcribe:
        print('Collecting videos')
        parsed_database = read_db()

        # Remove transcripts already processed
-        finished = set(
-        finished = set([x.split('.')[0] for x in finished])
+        finished = set(x.split('.')[0] for x in os.listdir(
+            'transcripts/auto/') + os.listdir('transcripts/manual/'))

        video_ids = list(parsed_database.keys() - finished)

@@ -592,7 +623,7 @@ def main():
            tasks, preprocess_args.num_jobs, callback).start()

    final_path = os.path.join(
-
+        dataset_args.data_dir, dataset_args.processed_file)

    if preprocess_args.do_create:
        print('Create final data')

@@ -601,22 +632,19 @@ def main():
        parsed_database = read_db()

-        # TODO add progress bar
        # TODO parallelise?
        with tqdm(total=len(parsed_database)) as progress:
            for index, (video_id, segments) in enumerate(parsed_database.items()):
                if preprocess_args.max_videos is not None and index >= preprocess_args.max_videos:
                    break
                progress.set_description(f'Processing {video_id}')
                progress.update()

-                final_data[video_id] = []
-
                video_words = get_words(video_id, process=False)
                if not video_words:
                    continue

+                final_vid_segs = []
                for seg in segments:  # Only add segments with high enough wps
                    segment_words = segment.extract_segment(
                        video_words, seg['start'], seg['end'])

@@ -634,7 +662,10 @@ def main():
                        # e.g. music ads with some words on each side
                        # progress.set_description(f'Skipping bad segment in {video_id} (wps={wps})')
                        continue
-
+                    final_vid_segs.append(seg)
+
+                if final_vid_segs:
+                    final_data[video_id] = final_vid_segs

        # Save data
        with open(final_path, 'w') as fp:

@@ -666,13 +697,12 @@ def main():
    if preprocess_args.do_generate:
        print('Generating')
-        from model import get_tokenizer
-
        # max_videos=preprocess_args.max_videos,
        # max_segments=preprocess_args.max_segments,
        # , max_videos, max_segments

-
+        from model import get_model_tokenizer
+        model, tokenizer = get_model_tokenizer(model_args.model_name_or_path)

        # TODO
        # count_videos = 0

@@ -685,8 +715,9 @@ def main():
        data = list(itertools.islice(data, start_index, end_index))

-
+        write_mode = 'w' if preprocess_args.overwrite else 'a'
+        with open(positive_file, write_mode, encoding='utf-8') as positive, \
+                open(negative_file, write_mode, encoding='utf-8') as negative, \
                tqdm(data) as progress:

            for offset, (video_id, sponsor_segments) in enumerate(data):

@@ -711,36 +742,36 @@ def main():
                    continue

                for seg in segments:
-
-                    # Ignore segments with "not enough words" in the transcript
-                    # Must do here since this includes non-sponsor segments
-                    if wps < preprocess_args.min_wps:
-
+                    seg_start = segment.word_start(seg[0])
+                    seg_end = segment.word_end(seg[-1])
+                    # duration = seg_end - seg_start
+                    # wps = len(seg)/duration if duration > 0 else 0

+                    # # Ignore segments with "not enough words" in the transcript
+                    # # Must do here since this includes non-sponsor segments
+                    # if wps < preprocess_args.min_wps:
+                    #     continue

                    d = {
                        'video_index': offset + start_index,
                        'video_id': video_id,
-                        'text':
+                        'text': ' '.join(x['cleaned'] for x in seg),
+                        'start': seg_start,
+                        'end': seg_end,
                    }

                    extracted_segments = extract_sponsors(seg)
                    if extracted_segments:
                        extracted_texts = []
                        for s in extracted_segments:
-                            w = ' '.join(q['
+                            w = ' '.join(q['cleaned'] for q in s['words'])
                            category = s['category'].upper()
                            extracted_texts.append(
                                f'{START_SEGMENT_TEMPLATE.format(category)} {w} {END_SEGMENT_TEMPLATE.format(category)}'
                            )

-                        d['extracted'] = clean_text(extracted_text)
+                        d['extracted'] = f' {CustomTokens.BETWEEN_SEGMENTS.value} '.join(
                            extracted_texts)
                        print(json.dumps(d), file=positive)

                    else:

@@ -824,14 +855,12 @@ def main():


 def split(arr, ratios):
-    """Split array according to ratios. Sum of ratios should be
+    """Split array according to ratios. Sum of ratios should be <= 1"""
    to_return = []

    cumulative_sum = 0
    for r in ratios:
        current = cumulative_sum
-
        cumulative_sum += r * len(arr)
        to_return.append(arr[int(current):int(cumulative_sum)])
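The main behavioural change in read_db() above is that filtering now happens per video rather than per submission row: a video is dropped if any segment predates min_date, if all segments postdate max_date, or if any unlocked segment lacks views. A minimal, self-contained sketch of that filter (the function name filter_videos and the sample thresholds are illustrative, not part of the commit):

    from datetime import datetime

    def filter_videos(db, min_date, max_date, min_views):
        # db maps video_id -> list of segment dicts with 'submission_time' (seconds),
        # 'views' and 'locked' keys, mirroring the structure built in read_db().
        for key in list(db):
            segments = db[key]
            if any(datetime.fromtimestamp(x['submission_time']) < min_date for x in segments):
                # A segment predates the window: the video was likely uploaded before min_date.
                del db[key]
            elif all(datetime.fromtimestamp(x['submission_time']) > max_date for x in segments):
                # Every segment is newer than max_date: the video entered the pool too late.
                del db[key]
            elif any(not x['locked'] and x['views'] < min_views for x in segments):
                # An unlocked segment has too few views: the video is not fully reviewed yet.
                del db[key]
        return db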
src/segment.py
CHANGED
@@ -5,27 +5,19 @@ from dataclasses import dataclass, field

 @dataclass
 class SegmentationArguments:
-    pause_threshold: int = field(default=2, metadata={
+    pause_threshold: int = field(default=2.5, metadata={
        'help': 'When the time between words is greater than pause threshold, force into a new segment'})


-# WORDS TO ALWAYS HAVE ON THEIR OWN
-# always_split_re = re.compile(r'\[\w+\]')
-# e.g., [Laughter], [Applause], [Music]
-always_split = [
-    CustomTokens.MUSIC.value,
-    CustomTokens.APPLAUSE.value,
-    CustomTokens.LAUGHTER.value
-]
-
-
 def get_overlapping_chunks_of_tokens(tokens, size, overlap):
     for i in range(0, len(tokens), size-overlap+1):
         yield tokens[i:i+size]


-# Generate up to max_tokens
-
+# Generate up to SAFETY_TOKENS_PERCENTAGE*max_tokens tokens
+MIN_SAFETY_TOKENS = 8
+SAFETY_TOKENS_PERCENTAGE = 0.9765625
+# e.g. 512 -> 500, 768 -> 750


 # TODO play around with this?

@@ -34,15 +26,9 @@ OVERLAP_TOKEN_PERCENTAGE = 0.5  # 0.25

 def add_labels_to_words(words, sponsor_segments):

-    for sponsor_segment in sponsor_segments:
-        if sponsor_segment['start'] <= word['start'] <= sponsor_segment['end']:
-            word['category'] = sponsor_segment['category']
-
-    # TODO use extract_segment with mapping function?
-    # TODO remove sponsor segments that contain mostly empty space?
+    for sponsor_segment in sponsor_segments:
+        for w in extract_segment(words, sponsor_segment['start'], sponsor_segment['end']):
+            w['category'] = sponsor_segment['category']

     return words
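The generate_segments hunk below replaces the broken second-pass logic with an explicit overlap buffer: batches are emitted once the token budget is reached, then trimmed from the front so that consecutive batches share roughly buffer_size tokens. A simplified, self-contained sketch of that strategy (it works on plain token counts and copies each batch before trimming, so it is not line-for-line identical to the diff):

    def split_with_overlap(token_counts, max_q_size, buffer_size):
        # Emit a batch whenever adding the next count would reach max_q_size,
        # then keep roughly buffer_size trailing tokens so batches overlap.
        segments = []
        current = []
        current_tokens = 0
        for count in token_counts:
            if current and current_tokens + count >= max_q_size:
                segments.append(list(current))
                while current_tokens > buffer_size and current:
                    current_tokens -= current.pop(0)
            current.append(count)
            current_tokens += count
        if current:
            segments.append(current)
        return segments

    print(split_with_overlap([3, 4, 5, 2, 6, 1, 4], max_q_size=10, buffer_size=5))
    # -> [[3, 4], [4, 5], [5, 2], [2, 6, 1], [1, 4]]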
@@ -69,84 +55,86 @@ def generate_segments(words, tokenizer, segmentation_args):

     for index, word in enumerate(words):
         # Get length of tokenized word
-        cleaned = preprocess.clean_text(word['text'])
+        word['cleaned'] = preprocess.clean_text(word['text'])
         word['num_tokens'] = len(
-            tokenizer(cleaned, add_special_tokens=False, truncation=True).input_ids)
-
-        add_new_segment = index == 0
-        if not add_new_segment:
-            # Pause too small, do not split
-            elif word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
-                add_new_segment = True
-
-        if add_new_segment:  # New segment
+            tokenizer(word['cleaned'], add_special_tokens=False, truncation=True).input_ids)
+
+        # Add new segment
+        if index == 0 or word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
             first_pass_segments.append([word])

         else:  # Add to current segment
             first_pass_segments[-1].append(word)

-    max_q_size = tokenizer.model_max_length
+    max_q_size = round(SAFETY_TOKENS_PERCENTAGE * tokenizer.model_max_length)

     buffer_size = OVERLAP_TOKEN_PERCENTAGE*max_q_size  # tokenizer.model_max_length

     # In second pass, we split those segments if too big
     second_pass_segments = []
+
     for segment in first_pass_segments:
         current_segment_num_tokens = 0
         current_segment = []
+
         for word in segment:
-            new_seg = current_segment_num_tokens +
+            new_seg = current_segment_num_tokens + \
+                word['num_tokens'] >= max_q_size
             if new_seg:
                 # Adding this token would make it have too many tokens
                 # We save this batch and create new
-                second_pass_segments.append(current_segment
+                second_pass_segments.append(current_segment)

             # Add tokens to current segment
             current_segment.append(word)
             current_segment_num_tokens += word['num_tokens']

-            if new_seg:
+            if not new_seg:
+                continue
+
+            # Just created a new segment, so we remove until we only have buffer_size tokens
+            last_index = 0
+            while current_segment_num_tokens > buffer_size and current_segment:
+                current_segment_num_tokens -= current_segment[last_index]['num_tokens']
+                last_index += 1
+
+            current_segment = current_segment[last_index:]

-        if current_segment:
-            second_pass_segments.append(current_segment
+        if current_segment:  # Add remaining segment
+            second_pass_segments.append(current_segment)

     # Cleaning up, delete 'num_tokens' from each word
-    for segment in second_pass_segments:
+    # for segment in second_pass_segments:
+    for word in words:
+        word.pop('num_tokens', None)
+
     return second_pass_segments


 def extract_segment(words, start, end, map_function=None):
     """Extracts all words with time in [start, end]"""
+
     a = binary_search(words, 0, len(words), start, True)
-    b = min(binary_search(words, 0, len(words), end
+    b = min(binary_search(words, 0, len(words), end, False) + 1, len(words))

     to_transform = map_function is not None and callable(map_function)
+
     return [
         map_function(words[i]) if to_transform else words[i] for i in range(a, b)
     ]

+
 def binary_search(words, start_index, end_index, time, below):
+    """Binary search to get first index of word whose start/end time is greater/less than some value"""
     if start_index >= end_index:
         return end_index

-    middle_index = (start_index + end_index ) // 2
+    middle_index = (start_index + end_index) // 2
+
+    middle_time = word_start(
+        words[middle_index]) if below else word_end(words[middle_index])

+    # TODO if above: if time < middle_time binary_search(start, middle-1)
     if time <= middle_time:
         return binary_search(words, start_index, middle_index, time, below)
     else: