Joshua Lochner committed
Commit 3879103 · 1 Parent(s): 915339e

Improve preprocessing

Files changed (1)
  1. src/preprocess.py +190 -175
src/preprocess.py CHANGED
@@ -1,3 +1,6 @@
+from utils import jaccard
+from shared import START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE
+from functools import lru_cache
 from datetime import datetime
 import itertools
 from typing import Optional, List
@@ -13,12 +16,12 @@ import re
 import random
 import logging
 from youtube_transcript_api import YouTubeTranscriptApi
-from youtube_transcript_api._errors import CouldNotRetrieveTranscript, YouTubeRequestFailed
+from youtube_transcript_api._errors import CouldNotRetrieveTranscript, YouTubeRequestFailed, TooManyRequests
 import os
 import json
 import time
 import requests
-from utils import InterruptibleThreadPool, Job
+from utils import Task, InterruptibleTaskPool
 
 
 def find(s, ch):
@@ -106,87 +109,84 @@ def get_auto_words(transcript_list):
     return words
 
 
+def list_transcripts(video_id):
+    return YouTubeTranscriptApi.list_transcripts(video_id)
+
+
+@lru_cache(maxsize=16)
 def get_words(video_id, process=True, fallback=True, transcript_type='auto'):
     """Get parsed video transcript with caching system
     returns None if not processed yet and process is False
     """
     get_manual_if_fail = fallback and transcript_type == 'auto'
-    transcript_path = os.path.join(
+    transcript_path = os.path.join(  # TODO use relative path to this
         'transcripts', transcript_type, f'{video_id}.json')
     words = []
     try:
-        if os.path.exists(transcript_path):
+        if os.path.exists(transcript_path):  # Load from file
             with open(transcript_path) as fp:
-                wds = json.load(fp)
-
-                if not wds and get_manual_if_fail:
-                    return get_words(video_id, process, fallback, 'manual')
-                return wds
-
-        elif not process:
-            return None
-
-        transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
-
-        if transcript_type == 'manual':
-            words = get_manual_words(transcript_list)
-        else:
-            words = get_auto_words(transcript_list)
+                words = json.load(fp)
+
+        elif process:
+            transcript_list = list_transcripts(video_id)
+
+            if transcript_type == 'manual':
+                words = get_manual_words(transcript_list)
+            else:
+                words = get_auto_words(transcript_list)
 
-    except YouTubeRequestFailed as e:
+    except (TooManyRequests, YouTubeRequestFailed, requests.exceptions.ConnectionError) as e:  # Can retry
         print(e)
-        time.sleep(30)  # Timeout
+        time.sleep(10)  # Timeout
         return get_words(video_id, process, fallback, transcript_type)
 
     except CouldNotRetrieveTranscript:
-        if get_manual_if_fail:
-            print('fallback')
-            return get_words(video_id, process, fallback, 'manual')
-
-    except json.decoder.JSONDecodeError:
-        # Warning, unable to parse JSON
         pass
+    except json.decoder.JSONDecodeError:
+        print('JSONDecodeError for', video_id)
+        os.remove(transcript_path)  # Remove file and try again
+        return get_words(video_id, process, fallback, transcript_type)
 
+    # Even save empty
     with open(transcript_path, 'w') as fp:
         json.dump(words, fp)
 
+    if not words and get_manual_if_fail:
+        return get_words(video_id, process, fallback, 'manual')
+
     return words
 
 
 # TODO make min_sponsor_segment_length param
-def extract_sponsors(words, min_sponsor_segment_length=5):
-    if len(words) < min_sponsor_segment_length:
-        return []  # Force short phrases to not be sponsors
+def extract_sponsors(words, min_sponsor_segment_length=3):
+    if not words:
+        return []
 
     paragraphs = []
     current = []
     prev_category = None
 
-    for word in words:
-        if word['category'] is None:  # and not current:
-            continue  # Skip unimportant
-
-        if word['category'] == prev_category:
-            current.append(word['text'])
-        else:
-            paragraphs.append({
-                'words': current,
-                'category': prev_category,
-            })
-            current = []
-
-        prev_category = word['category']
-
-    if current and prev_category is not None:
-        paragraphs.append({
-            'words': current,
-            'category': prev_category,
-        })
-
-    # Remove all too short:
-    paragraphs = list(filter(lambda x: len(
-        x['words']) >= min_sponsor_segment_length, paragraphs))
-
-    return paragraphs
+    i = 0
+    while i <= len(words):
+        unimportant = i == len(words) or words[i]['category'] is None
+
+        if unimportant or words[i]['category'] != prev_category:
+            if current:  # Save the current batch
+                paragraphs.append({
+                    'words': current,
+                    'category': current[-1]['category'],
+                })
+
+                current = []
+
+        if not unimportant:  # Some useful information to save
+            current.append(words[i])
+            prev_category = words[i]['category']
+
+        i += 1
+
+    # Remove all too short:
+    return list(filter(lambda x: len(x['words']) >= min_sponsor_segment_length, paragraphs))
 
 
 def clean_text(text):
@@ -231,33 +231,27 @@ def clean_text(text):
     return text.strip()
 
 
-def remove_duplicate_sponsor_segments(sponsor_segments):
-    """Choose the best sponsor segment if overlapping with others"""
-
+def remove_duplicate_segments(segments):
     # Algorithm based on SponsorBlock algorithm
+    # https://blog.ajay.app/voting-and-pseudo-randomness-or-sponsorblock-or-youtube-sponsorship-segment-blocker
     # Find sponsors that are overlapping
-    similar = []
-    for i in sponsor_segments:
-        for j in sponsor_segments:
-            # Since we do pairwise, we only check one direction
-            if (j['start'] >= i['start'] and j['start'] <= i['end']):
-                similar.append([i, j])
-
-    # Within each group, choose the segment with the most votes.
-    processed = []
-    best = []
-    for i in similar:
-        if i in processed:
-            continue
-        group = i
-        for j in similar:
-            if j[0] in group or j[1] in group:  # If either in, append both
-                group.append(j[0])
-                group.append(j[1])
-                processed.append(j)
 
-        best.append(max(group, key=lambda item: (
-            item['votes'], item['reputation'], item['views'])))
+    best = []
+    for i in segments:
+        similar_segments = []
+        for j in segments:
+            if jaccard(i['start'], i['end'], j['start'], j['end']) > 0.1:  # Some overlap
+                similar_segments.append(j)
+
+        if similar_segments:
+            best_similar_seg = max(similar_segments, key=lambda item: (
+                item['locked'],
+                item['votes'],
+                item['views'],
+                item['reputation']
+            ))
+            if best_similar_seg not in best:
+                best.append(best_similar_seg)
 
     return best
 
@@ -280,16 +274,25 @@ class PreprocessArguments:
     # Downvotes will make this negative.
     # 1 = At least one positive vote
 
+    min_views: int = field(
+        default=5, metadata={'help': 'Minimum number of views a segment must have to be considered. 0 = show all'})
+
     min_date: str = field(
-        default='20/08/2021', metadata={'help': 'Only use submissions from after this date, defaults to the release of v3.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/3.0)'})
+        # release of v2.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/2.0)
+        default='08/06/2020',
+        # default='20/08/2021',  # release of v3.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/3.0)
+        # default='01/10/2020',  # No more autovote
+        metadata={'help': 'Only use submissions from after this date'})
 
+    # TODO move?
     categories: str = field(
         default_factory=lambda: ['sponsor', 'selfpromo', 'interaction'],
         metadata={
             'nargs': '+',
-            'choices': ['intro', 'sponsor', 'interaction',
-                        'outro', 'selfpromo', 'preview',
-                        'poi_highlight', 'filler', 'music_offtopic']  # moreCategories
+            'choices': ['intro', 'sponsor', 'interaction']
+            # 'outro', 'selfpromo', 'preview',
+            # 'poi_highlight', 'filler', 'music_offtopic',
+            # 'moreCategories'
         }
     )
 
@@ -345,7 +348,7 @@ class PreprocessArguments:
     )
 
     min_wps: float = field(
-        default=0.4, metadata={'help': 'Ignore videos with not enough words spoken per second. This is usually indicitive of video whose captions aren\'t English.'})
+        default=1.5, metadata={'help': 'Ignore videos with not enough words spoken per second. This is usually indicitive of video whose captions aren\'t English.'})
     # 0.1 ~ 1%
     # 0.4 ~ 2.5%
     # 0.9 ~ 5%
@@ -357,7 +360,7 @@ MIRRORS = [
     'https://sb-mirror.mchang.xyz/sponsorTimes.csv',  # 5 minute delay
     'https://sb.ltn.fi/database/sponsorTimes.csv',  # 5 minute delay
 ]
-# TODO only download latest (updates/changes)
+# TODO only download latest updates/changes
 
 
 def download_file(url, filename):
@@ -480,7 +483,18 @@ def main():
     raw_dataset_path = os.path.join(
         preprocess_args.raw_data_dir, preprocess_args.raw_data_file)
 
-    def get_rows():
+    if preprocess_args.update_database:
+        print('Updating database')
+        for mirror in MIRRORS:
+            print('Downloading from', mirror)
+            if download_file(mirror, raw_dataset_path):
+                break
+            print('Failed, trying next')
+
+    @lru_cache
+    def read_db():  # TODO save as file
+        print('Parsing raw database')
+        db = {}
 
         latest_time = datetime.strptime(preprocess_args.min_date, '%d/%m/%Y')
@@ -488,10 +502,9 @@ def main():
             reader = csv.DictReader(csvfile)
 
             for line in reader:
-                submitted_time = datetime.fromtimestamp(
-                    float(line['timeSubmitted'])/1e3)
+                submission_time = float(line['timeSubmitted'])/1e3
 
-                if submitted_time < latest_time:
+                if datetime.fromtimestamp(submission_time) < latest_time:
                     continue
 
                 if line['service'] != 'YouTube':
@@ -499,7 +512,6 @@ def main():
                 if len(line['videoID']) != 11:
                     continue  # Invalid youtube video ID
 
-                # TODO add support for other categories and action types?
                 if line['category'] not in preprocess_args.categories:
                     continue
                 if line['actionType'] != 'skip':
@@ -511,53 +523,72 @@ def main():
 
                 # Skip those that aren't highly voted
                 line['votes'] = int(line['votes'])
-                # incorrect_votes = int(line['incorrectVotes'])
-
                 if line['votes'] < preprocess_args.min_votes:
                     continue
 
-                yield line
+                locked = line['locked'] == '1'
 
-    if preprocess_args.update_database:
-        print('Updating database')
-        for mirror in MIRRORS:
-            print('Downloading from', mirror)
-            if download_file(mirror, raw_dataset_path):
-                break
-            print('Failed, trying next')
+                # Skip segments with low views (i.e., not really reviewed)
+                # Always include segments locked by VIPs, regardless of view count
+                line['views'] = int(line['views'])
+                if not locked and line['views'] < preprocess_args.min_views:
+                    continue
+
+                if line['videoID'] not in db:
+                    db[line['videoID']] = []
+
+                db[line['videoID']].append({
+                    'uuid': line['UUID'],
+                    'start': float(line['startTime']),
+                    'end': float(line['endTime']),
+                    'votes': line['votes'],
+                    'locked': locked,
+                    'views': line['views'],
+                    'submission_time': submission_time,
+                    'reputation': line['reputation'],
+                    'category': line['category'],
+                    'action': line['actionType'],
+                })
+
+        num_segments = 0
+
+        # Remove duplicate sponsor segments by choosing best (most votes)
+        print('Remove duplicate segments')
+        for key in db:
+            db[key] = remove_duplicate_segments(db[key])
+            num_segments += len(db[key])
+        print('Saved', len(db), 'videos and', num_segments, 'segments')
+
+        return db
 
     # 'videoID', 'startTime', 'endTime', 'votes', 'locked', 'incorrectVotes', 'UUID',
     # 'userID', 'timeSubmitted', 'views', 'category', 'actionType', 'service', 'videoDuration',
     # 'hidden', 'reputation', 'shadowHidden', 'hashedVideoID', 'userAgent', 'description'
-    data_rows = None
+    parsed_database = None
     if preprocess_args.do_transcribe:
         print('Collecting videos')
-        video_ids = set()
-        data_rows = get_rows()
-        for row in data_rows:
-            video_ids.add(row['videoID'])
+        parsed_database = read_db()
 
-        # TODO first set - os.listdir and do rest
-
-        print('Start transcribing')
-        with tqdm(total=len(video_ids)) as progress:
-            def on_job_complete(job):
-                progress.set_description(f'Processed {job.video_id}')
-                progress.update()
-
-            pool = InterruptibleThreadPool(
-                preprocess_args.num_jobs, on_job_complete=on_job_complete)
-
-            print('Adding jobs to pool')
-            for video_id in video_ids:
-                job = Job(get_words, video_id)
-                job.video_id = video_id
-                pool.add_job(job)
-
-            print('Start processing')
-            pool.run()
-
-        print('Finished transcribing')
+        # Remove transcripts already processed
+        finished = set(os.listdir('transcripts/auto/') +
+                       os.listdir('transcripts/manual/'))
+        finished = set([x.split('.')[0] for x in finished])
+
+        video_ids = list(parsed_database.keys() - finished)
+
+        # Create tasks generator
+        tasks = (
+            Task(get_words, video_id)
+            for video_id in video_ids
+        )
+
+        print('start')
+        with tqdm(total=len(video_ids)) as progress:
+            def callback(task):
+                progress.set_description(f'Processing {task.args[0]}')
+                progress.update()
+
+            InterruptibleTaskPool(tasks, preprocess_args.num_jobs, callback).start()
 
     final_path = os.path.join(
         processed_args.processed_dir, processed_args.processed_file)
@@ -567,56 +598,42 @@ def main():
 
     final_data = {}
 
-    if data_rows is None:
-        data_rows = get_rows()
-        # data_rows = itertools.islice(data_rows, 1000)  # TODO temp
+    parsed_database = read_db()
 
     # TODO add progress bar
     # TODO parallelise?
-    for index, line in enumerate(data_rows):
-        video_id = line['videoID']
-
-        if video_id not in final_data:
-            final_data[video_id] = []
-
-        segment_start = float(line['startTime'])
-        segment_end = float(line['endTime'])
-
-        video_words = get_words(video_id, process=False)
-        if not video_words:
-            continue
-
-        segment_words = segment.extract_segment(
-            video_words, segment_start, segment_end)
-
-        if len(segment_words) <= 1:
-            continue  # Useless to add segment since no words
-
-        # duration = segment.word_end(segment_words[-1]) - segment.word_start(segment_words[0])
-        duration = segment_end - segment_start
-        wps = len(segment_words)/duration if duration > 0 else 0
-
-        if wps < preprocess_args.min_wps:
-            print(index, 'Skipping bad segment in',
-                  video_id, '| wps =', wps)
-            continue
-
-        final_data[video_id].append({
-            'start': segment_start,
-            'end': segment_end,
-            'votes': line['votes'],
-            'locked': line['locked'] == '1',
-            'views': line['views'],
-            'reputation': line['reputation'],
-            'category': line['category'],
-            'action': line['actionType'],
-            'uuid': line['UUID'],
-        })
-
-    # Remove duplicate sponsor segments by choosing best (most votes)
-    for key in final_data:
-        final_data[key] = remove_duplicate_sponsor_segments(
-            final_data[key])
+    with tqdm(total=len(parsed_database)) as progress:
+        for index, (video_id, segments) in enumerate(parsed_database.items()):
+
+            if preprocess_args.max_videos is not None and index >= preprocess_args.max_videos:
+                break
+            progress.set_description(f'Processing {video_id}')
+            progress.update()
+
+            final_data[video_id] = []
+
+            video_words = get_words(video_id, process=False)
+            if not video_words:
+                continue
+
+            for seg in segments:  # Only add segments with high enough wps
+                segment_words = segment.extract_segment(
+                    video_words, seg['start'], seg['end'])
+
+                if len(segment_words) <= 1:
+                    continue  # Useless to add segment since no words
+
+                # duration = segment.word_end(segment_words[-1]) - segment.word_start(segment_words[0])
+                duration = seg['end'] - seg['start']
+                wps = len(segment_words)/duration if duration > 0 else 0
+
+                # print(video_id, wps)
+                if wps < preprocess_args.min_wps:
+                    # Skip sponsor segments without many words
+                    # e.g. music ads with some words on each side
+                    # progress.set_description(f'Skipping bad segment in {video_id} (wps={wps})')
+                    continue
+                final_data[video_id].append(seg)
 
     # Save data
     with open(final_path, 'w') as fp:
@@ -656,8 +673,9 @@ def main():
 
     tokenizer = get_tokenizer(model_args)
 
-    count_videos = 0
-    count_segments = 0  # TODO
+    # TODO
+    # count_videos = 0
+    # count_segments = 0
 
     write_mode = 'w' if preprocess_args.overwrite else 'a'
@@ -682,15 +700,15 @@ def main():
             open(negative_file, write_mode, encoding='utf-8') as negative, \
             tqdm(total=total) as progress:
 
-        for video_id, sponsor_segments in data:
+        for ind, (video_id, sponsor_segments) in enumerate(data):
             index += 1  # TODO FIX index + incrementing
-            progress.set_description(f'Processing {video_id}')
 
-            if get_all:
-                progress.update()
-            elif count_videos >= preprocess_args.max_videos:
+            if preprocess_args.max_videos is not None and ind >= preprocess_args.max_videos:
                 break
 
+            progress.set_description(f'Processing {video_id}')
+            progress.update()
+
             words = get_words(video_id, process=False)
             if not words:
                 continue
@@ -707,16 +725,13 @@ def main():
             if not segments:
                 continue
 
-            count_videos += 1
-            if not get_all:
-                progress.update()
-
             for seg in segments:
                 duration = segment.word_end(
                     seg[-1]) - segment.word_start(seg[0])
                 wps = len(seg)/duration if duration > 0 else 0
 
                 # Ignore segments with "not enough words" in the transcript
+                # Must do here since this includes non-sponsor segments
                 if wps < preprocess_args.min_wps:
                     continue
 
@@ -732,13 +747,13 @@ def main():
                 if extracted_segments:
                     extracted_texts = []
                     for s in extracted_segments:
-                        w = ' '.join(s['words'])
+                        w = ' '.join([q['text'] for q in s['words']])
                         category = s['category'].upper()
 
-                        t = f"{CustomTokens.START_SEGMENT.value}_{category} {w} {CustomTokens.END_SEGMENT.value}_{category}"
-                        extracted_texts.append(t)
-
-                    extracted_text = '\n'.join(extracted_texts)
+                        extracted_texts.append(
+                            f"{START_SEGMENT_TEMPLATE.format(category)} {w} {END_SEGMENT_TEMPLATE.format(category)}")
+
+                    extracted_text = f' {CustomTokens.BETWEEN_SEGMENTS.value} '.join(
+                        extracted_texts)
 
                     d['extracted'] = clean_text(extracted_text)
                     print(json.dumps(d), file=positive)
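Note: the new remove_duplicate_segments relies on a jaccard helper imported from utils, which this commit does not touch. A minimal sketch of what an interval-based Jaccard similarity with that call signature could look like (an assumption for illustration, not part of this diff):

def jaccard(start1, end1, start2, end2):
    # Overlap between the two time intervals (0 if they are disjoint)
    intersection = max(0, min(end1, end2) - max(start1, start2))
    # Total time covered by either interval
    union = (end1 - start1) + (end2 - start2) - intersection
    return intersection / union if union > 0 else 0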
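Similarly, the do_transcribe block builds a generator of Task objects and hands it to InterruptibleTaskPool, both imported from utils and not shown here. A rough, hypothetical sketch of the interface that usage implies (assuming a simple thread pool that runs each task and reports completion through the callback):

from concurrent.futures import ThreadPoolExecutor


class Task:
    def __init__(self, function, *args, **kwargs):
        self.function = function
        self.args = args          # the callback above reads task.args[0] (the video ID)
        self.kwargs = kwargs

    def run(self):
        return self.function(*self.args, **self.kwargs)


class InterruptibleTaskPool:
    def __init__(self, tasks, num_workers, callback=None):
        self.tasks = tasks        # any iterable of Task objects (here, a generator)
        self.num_workers = num_workers
        self.callback = callback  # called with each finished task

    def start(self):
        def run_one(task):
            task.run()
            if self.callback is not None:
                self.callback(task)
            return task

        try:
            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                list(executor.map(run_one, self.tasks))
        except KeyboardInterrupt:
            print('Interrupted, waiting for running tasks to finish')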