Joshua Lochner committed on
Commit b27b0d5
1 Parent(s): bce5ce9

Improve preprocessing and segmentation

Files changed (2)
  1. src/preprocess.py +140 -111
  2. src/segment.py +44 -56
src/preprocess.py CHANGED
@@ -1,3 +1,4 @@
+from shared import CATGEGORY_OPTIONS
 from utils import jaccard
 from shared import START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE
 from functools import lru_cache
@@ -113,19 +114,26 @@ def list_transcripts(video_id):
     return YouTubeTranscriptApi.list_transcripts(video_id)
 
 
+WORDS_TO_REMOVE = [
+    CustomTokens.MUSIC.value,
+    CustomTokens.APPLAUSE.value,
+    CustomTokens.LAUGHTER.value
+]
+
+
 @lru_cache(maxsize=16)
-def get_words(video_id, process=True, fallback=True, transcript_type='auto'):
+def get_words(video_id, process=True, transcript_type='auto', fallback='manual', filter_words_to_remove=True):
     """Get parsed video transcript with caching system
     returns None if not processed yet and process is False
     """
-    get_manual_if_fail = fallback and transcript_type == 'auto'
     transcript_path = os.path.join(  # TODO use relative path to this
         'transcripts', transcript_type, f'{video_id}.json')
-    words = []
+
+    words = None
     try:
         if os.path.exists(transcript_path):  # Load from file
             with open(transcript_path) as fp:
-                words = json.load(fp)
+                words = json.load(fp)  # May be empty
 
         elif process:
             transcript_list = list_transcripts(video_id)
@@ -138,52 +146,55 @@ def get_words(video_id, process=True, fallback=True, transcript_type='auto'):
     except (TooManyRequests, YouTubeRequestFailed, requests.exceptions.ConnectionError) as e:  # Can retry
         print(e)
         time.sleep(10)  # Timeout
-        return get_words(video_id, process, fallback, transcript_type)
+        return get_words(video_id, process, transcript_type, fallback)
 
     except CouldNotRetrieveTranscript:
         pass
+
     except json.decoder.JSONDecodeError:
         print('JSONDecodeError for', video_id)
         os.remove(transcript_path)  # Remove file and try again
-        return get_words(video_id, process, fallback, transcript_type)
+        return get_words(video_id, process, transcript_type, fallback)
+
+    # Tried to process it, but it was empty...
+    if process and not os.path.exists(transcript_path):
+        with open(transcript_path, 'w') as fp:
+            json.dump(words, fp)
 
-    # Even save empty
-    with open(transcript_path, 'w') as fp:
-        json.dump(words, fp)
+    if not words and fallback is not None:
+        return get_words(video_id, process, transcript_type=fallback, fallback=None)
 
-    if not words and get_manual_if_fail:
-        return get_words(video_id, process, fallback, 'manual')
+    if words and filter_words_to_remove:
+        words = list(filter(lambda x: x['text'] not in WORDS_TO_REMOVE, words))
 
     return words
 
 
 # TODO make min_sponsor_segment_length param
+# TODO rename to extract_segments
 def extract_sponsors(words, min_sponsor_segment_length=3):
-    if not words:
+    if not words or len(words) < min_sponsor_segment_length:
         return []
 
     paragraphs = []
     current = []
     prev_category = None
 
-    i = 0
-    while i <= len(words):
-        unimportant = i == len(words) or words[i]['category'] is None
+    for i in range(len(words) + 1):
+        unimportant = i == len(words) or words[i].get('category') is None
 
-        if unimportant or words[i]['category'] != prev_category:
+        if unimportant or words[i].get('category') != prev_category:
             if current:  # Save the current batch
                 paragraphs.append({
                     'words': current,
-                    'category': current[-1]['category'],
+                    'category': current[-1].get('category'),
                 })
 
                 current = []
 
         if not unimportant:  # Some useful information to save
             current.append(words[i])
-            prev_category = words[i]['category']
-
-        i += 1
+            prev_category = words[i].get('category')
 
     # Remove all too short:
     return list(filter(lambda x: len(x['words']) >= min_sponsor_segment_length, paragraphs))
@@ -277,24 +288,20 @@ class PreprocessArguments:
     min_views: int = field(
         default=5, metadata={'help': 'Minimum number of views a segment must have to be considered. 0 = show all'})
 
+    # min_reputation: int = field(
+    #     default=0, metadata={'help': 'Minimum reputation a user must have for the segment to be included'})
+
     min_date: str = field(
-        # release of v2.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/2.0)
-        default='08/06/2020',
-        # default='20/08/2021', # release of v3.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/3.0)
+        # default='08/06/2020', # release of v2.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/2.0)
+        # release of v3.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/3.0)
+        default='20/08/2021',
         # default='01/10/2020', # No more autovote
-        metadata={'help': 'Only use submissions from after this date'})
+        metadata={'help': 'Only use submissions from after this date (inclusive)'})
 
-    # TODO move?
-    categories: str = field(
-        default_factory=lambda: ['sponsor', 'selfpromo', 'interaction'],
-        metadata={
-            'nargs': '+',
-            'choices': ['intro', 'sponsor', 'interaction']
-            # 'outro', 'selfpromo', 'preview',
-            # 'poi_highlight', 'filler', 'music_offtopic',
-            # 'moreCategories'
-        }
-    )
+    max_date: str = field(
+        # default='01/01/9999', # Include all
+        default='27/01/2022',
+        metadata={'help': 'Only use videos that have some segment from before this date (exclusive). This allows for videos to have segments be corrected, but ignores new videos (posted after this date) to enter the pool.'})
 
     do_transcribe: bool = field(
         default=False, metadata={'help': 'Get transcripts for videos'}
@@ -302,9 +309,9 @@ class PreprocessArguments:
     num_jobs: int = field(
         default=4, metadata={'help': 'Number of transcripts to download in parallel'})
 
-    # append: bool = field(
-    #     default=False, metadata={'help': 'Append to training, testing and validation data, if present.'}
-    # )
+    overwrite: bool = field(
+        default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
+    )
 
     do_generate: bool = field(
         default=False, metadata={'help': 'Generate labelled data.'}
@@ -381,22 +388,6 @@ def download_file(url, filename):
     return total_bytes == os.path.getsize(filename)
 
 
-@dataclass
-class ProcessedArguments:
-    processed_dir: Optional[str] = field(
-        default='processed',
-        metadata={
-            'help': 'Processed data directory'
-        },
-    )
-    processed_file: Optional[str] = field(
-        default='final.json',
-        metadata={
-            'help': 'Processed data file'
-        },
-    )
-
-
 def load_datasets(dataset_args):
     print('Reading datasets')
     data_files = {}
@@ -411,7 +402,7 @@ def load_datasets(dataset_args):
         data_files['test'] = os.path.join(
             dataset_args.data_dir, dataset_args.test_file)
 
-    return load_dataset('json', data_files=data_files)
+    return load_dataset('json', data_files=data_files, cache_dir=dataset_args.dataset_cache_dir)
 
 
 @dataclass
@@ -422,6 +413,18 @@ class DatasetArguments:
             'help': 'The directory which stores train, test and/or validation data.'
         },
     )
+    processed_file: Optional[str] = field(
+        default='segments.json',
+        metadata={
+            'help': 'Processed data file'
+        },
+    )
+    processed_database: Optional[str] = field(
+        default='processed_database.json',
+        metadata={
+            'help': 'Processed database file'
+        },
+    )
 
     train_file: Optional[str] = field(
         default='train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
@@ -444,7 +447,12 @@ class DatasetArguments:
             'help': 'The excess segments left after the split'
        },
     )
-
+    dataset_cache_dir: Optional[str] = field(
+        default=None,
+        metadata={
+            'help': 'Where to store the cached datasets'
+        },
+    )
     overwrite_cache: bool = field(
         default=False, metadata={'help': 'Overwrite the cached training and evaluation sets'}
     )
@@ -472,13 +480,12 @@ def main():
     # Generate final.json from sponsorTimes.csv
     hf_parser = HfArgumentParser((
         PreprocessArguments,
-        ProcessedArguments,
         DatasetArguments,
         segment.SegmentationArguments,
        ModelArguments,
        GeneralArguments
     ))
-    preprocess_args, processed_args, dataset_args, segmentation_args, model_args, _ = hf_parser.parse_args_into_dataclasses()
+    preprocess_args, dataset_args, segmentation_args, model_args, _ = hf_parser.parse_args_into_dataclasses()
 
     raw_dataset_path = os.path.join(
         preprocess_args.raw_data_dir, preprocess_args.raw_data_file)
@@ -491,28 +498,28 @@ def main():
             break
         print('Failed, trying next')
 
-    @lru_cache
-    def read_db():  # TODO save as file
-        print('Parsing raw database')
-        db = {}
+    processed_db_path = os.path.join(
+        dataset_args.data_dir, dataset_args.processed_database)
 
-        latest_time = datetime.strptime(preprocess_args.min_date, '%d/%m/%Y')
+    def read_db():
+        if not preprocess_args.overwrite and os.path.exists(processed_db_path):
+            with open(processed_db_path) as fp:
+                return json.load(fp)
+        print('Processing raw database')
+        db = {}
 
+        allowed_categories = list(map(str.lower, CATGEGORY_OPTIONS))
        with open(raw_dataset_path, newline='') as csvfile:
             reader = csv.DictReader(csvfile)
 
             for line in reader:
-                submission_time = float(line['timeSubmitted'])/1e3
-
-                if datetime.fromtimestamp(submission_time) < latest_time:
-                    continue
 
                 if line['service'] != 'YouTube':
                     continue
                 if len(line['videoID']) != 11:
                     continue  # Invalid youtube video ID
 
-                if line['category'] not in preprocess_args.categories:
+                if line['category'] not in allowed_categories:
                     continue
                 if line['actionType'] != 'skip':
                     continue
@@ -522,17 +529,18 @@ def main():
                     continue
 
                 # Skip those that aren't highly voted
-                line['votes'] = int(line['votes'])
-                if line['votes'] < preprocess_args.min_votes:
+                votes = int(line['votes'])
+                if votes < preprocess_args.min_votes:
                     continue
 
                 locked = line['locked'] == '1'
 
-                # Skip segments with low views (i.e., not really reviewed)
-                # Always include segments locked by VIPs, regardless of view count
-                line['views'] = int(line['views'])
-                if not locked and line['views'] < preprocess_args.min_views:
-                    continue
+                reputation = float(line['reputation'])
+                # if reputation < preprocess_args.min_reputation:
+                #     continue  # TODO add back?
+                # Problems like mGVn1wCkBrE
+
+                # TODO ignore if over max_duration
 
                 if line['videoID'] not in db:
                     db[line['videoID']] = []
@@ -541,15 +549,37 @@ def main():
                     'uuid': line['UUID'],
                     'start': float(line['startTime']),
                     'end': float(line['endTime']),
-                    'votes': line['votes'],
+                    'votes': votes,
                     'locked': locked,
-                    'views': line['views'],
-                    'submission_time': submission_time,
-                    'reputation': line['reputation'],
+                    'views': int(line['views']),
+                    'submission_time': float(line['timeSubmitted'])/1e3,
+                    'reputation': reputation,
                     'category': line['category'],
-                    'action': line['actionType'],
+                    # 'action': line['actionType'],
                 })
 
+        # We now remove whole videos from the list
+        # Helps with obtaining "fully-labelled" videos
+        min_date = datetime.strptime(preprocess_args.min_date, '%d/%m/%Y')
+        max_date = datetime.strptime(preprocess_args.max_date, '%d/%m/%Y')
+        for key in list(db):
+
+            if any(datetime.fromtimestamp(x['submission_time']) < min_date for x in db[key]):
+                # Remove videos where any of its segments were submitted before min_date
+                # (essentially removes videos uploaded before min_date)
+                # Prevents issues where some segments of a video are excluded
+                del db[key]
+            elif all(datetime.fromtimestamp(x['submission_time']) > max_date for x in db[key]):
+                # Remove videos where all of its segments were submitted after max_date
+                # (essentially removes videos uploaded after max_date)
+                # Allows for segments to be corrected for past videos
+                del db[key]
+            elif any(not x['locked'] and x['views'] < preprocess_args.min_views for x in db[key]):
+                # Remove videos where any of its non-locked segments do not have enough views
+                # (essentially skips videos that have not been fully watched/reviewed)
+                # Always include segments locked by VIPs, regardless of view count
+                del db[key]
+
         num_segments = 0
 
         # Remove duplicate sponsor segments by choosing best (most votes)
@@ -559,20 +589,21 @@ def main():
             num_segments += len(db[key])
         print('Saved', len(db), 'videos and', num_segments, 'segments')
 
+        with open(processed_db_path, 'w') as fp:
+            json.dump(db, fp)
+
         return db
 
     # 'videoID', 'startTime', 'endTime', 'votes', 'locked', 'incorrectVotes', 'UUID',
     # 'userID', 'timeSubmitted', 'views', 'category', 'actionType', 'service', 'videoDuration',
     # 'hidden', 'reputation', 'shadowHidden', 'hashedVideoID', 'userAgent', 'description'
-    parsed_database = None
     if preprocess_args.do_transcribe:
         print('Collecting videos')
         parsed_database = read_db()
 
         # Remove transcripts already processed
-        finished = set(os.listdir('transcripts/auto/') +
-                       os.listdir('transcripts/manual/'))
-        finished = set([x.split('.')[0] for x in finished])
+        finished = set(x.split('.')[0] for x in os.listdir(
+            'transcripts/auto/') + os.listdir('transcripts/manual/'))
 
         video_ids = list(parsed_database.keys() - finished)
 
@@ -592,7 +623,7 @@ def main():
             tasks, preprocess_args.num_jobs, callback).start()
 
     final_path = os.path.join(
-        processed_args.processed_dir, processed_args.processed_file)
+        dataset_args.data_dir, dataset_args.processed_file)
 
     if preprocess_args.do_create:
         print('Create final data')
@@ -601,22 +632,19 @@ def main():
 
         parsed_database = read_db()
 
-        # TODO add progress bar
         # TODO parallelise?
         with tqdm(total=len(parsed_database)) as progress:
             for index, (video_id, segments) in enumerate(parsed_database.items()):
-
                 if preprocess_args.max_videos is not None and index >= preprocess_args.max_videos:
                     break
                 progress.set_description(f'Processing {video_id}')
                 progress.update()
 
-                final_data[video_id] = []
-
                 video_words = get_words(video_id, process=False)
                 if not video_words:
                     continue
 
+                final_vid_segs = []
                 for seg in segments:  # Only add segments with high enough wps
                     segment_words = segment.extract_segment(
                         video_words, seg['start'], seg['end'])
@@ -634,7 +662,10 @@ def main():
                         # e.g. music ads with some words on each side
                        # progress.set_description(f'Skipping bad segment in {video_id} (wps={wps})')
                        continue
-                    final_data[video_id].append(seg)
+                    final_vid_segs.append(seg)
+
+                if final_vid_segs:
+                    final_data[video_id] = final_vid_segs
 
         # Save data
         with open(final_path, 'w') as fp:
@@ -666,13 +697,12 @@ def main():
 
     if preprocess_args.do_generate:
         print('Generating')
-        from model import get_tokenizer
-
         # max_videos=preprocess_args.max_videos,
         # max_segments=preprocess_args.max_segments,
         # , max_videos, max_segments
 
-        tokenizer = get_tokenizer(model_args)
+        from model import get_model_tokenizer
+        model, tokenizer = get_model_tokenizer(model_args.model_name_or_path)
 
         # TODO
        # count_videos = 0
@@ -685,8 +715,9 @@ def main():
 
        data = list(itertools.islice(data, start_index, end_index))
 
-        with open(positive_file, 'a', encoding='utf-8') as positive, \
-                open(negative_file, 'a', encoding='utf-8') as negative, \
+        write_mode = 'w' if preprocess_args.overwrite else 'a'
+        with open(positive_file, write_mode, encoding='utf-8') as positive, \
+                open(negative_file, write_mode, encoding='utf-8') as negative, \
                 tqdm(data) as progress:
 
            for offset, (video_id, sponsor_segments) in enumerate(data):
@@ -711,36 +742,36 @@ def main():
                    continue
 
                for seg in segments:
-                    duration = segment.word_end(
-                        seg[-1]) - segment.word_start(seg[0])
-                    wps = len(seg)/duration if duration > 0 else 0
+                    seg_start = segment.word_start(seg[0])
+                    seg_end = segment.word_end(seg[-1])
+                    # duration = seg_end - seg_start
+                    # wps = len(seg)/duration if duration > 0 else 0
 
-                    # Ignore segments with "not enough words" in the transcript
-                    # Must do here since this includes non-sponsor segments
-                    if wps < preprocess_args.min_wps:
-                        continue
+                    # # Ignore segments with "not enough words" in the transcript
+                    # # Must do here since this includes non-sponsor segments
+                    # if wps < preprocess_args.min_wps:
+                    #     continue
 
                    d = {
                        'video_index': offset + start_index,
                        'video_id': video_id,
-                        'text': clean_text(' '.join(x['text'] for x in seg)),
-                        'words_per_second': round(wps, 3),
+                        'text': ' '.join(x['cleaned'] for x in seg),
+                        'start': seg_start,
+                        'end': seg_end,
                    }
 
                    extracted_segments = extract_sponsors(seg)
                    if extracted_segments:
                        extracted_texts = []
                        for s in extracted_segments:
-                            w = ' '.join(q['text'] for q in s['words'])
+                            w = ' '.join(q['cleaned'] for q in s['words'])
                            category = s['category'].upper()
                            extracted_texts.append(
                                f'{START_SEGMENT_TEMPLATE.format(category)} {w} {END_SEGMENT_TEMPLATE.format(category)}'
                            )
 
-                        extracted_text = f' {CustomTokens.BETWEEN_SEGMENTS.value} '.join(
+                        d['extracted'] = f' {CustomTokens.BETWEEN_SEGMENTS.value} '.join(
                            extracted_texts)
-
-                        d['extracted'] = clean_text(extracted_text)
                        print(json.dumps(d), file=positive)
 
                    else:
@@ -824,14 +855,12 @@
 
 
 def split(arr, ratios):
-    """Split array according to ratios. Sum of ratios should be less than 1"""
-
+    """Split array according to ratios. Sum of ratios should be <= 1"""
    to_return = []
 
    cumulative_sum = 0
    for r in ratios:
        current = cumulative_sum
-
        cumulative_sum += r * len(arr)
        to_return.append(arr[int(current):int(cumulative_sum)])
 
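Editor's note: the reworked get_words() above replaces the boolean fallback flag with a transcript-type fallback that retries at most once. Below is a minimal, self-contained sketch of that fallback chain, not code from this commit; the FAKE_TRANSCRIPTS table, the get_words_sketch name and the example video ID are illustrative stand-ins for the transcripts/<type>/<video_id>.json files.

# Illustrative only: an empty 'auto' transcript triggers exactly one retry with the
# 'manual' type; the retry passes fallback=None so it cannot recurse any further.
FAKE_TRANSCRIPTS = {  # hypothetical stand-in for transcripts/<type>/<video_id>.json
    ('auto', 'dQw4w9WgXcQ'): [],
    ('manual', 'dQw4w9WgXcQ'): [{'text': 'hello', 'start': 0.0, 'end': 0.4}],
}

def get_words_sketch(video_id, transcript_type='auto', fallback='manual'):
    words = FAKE_TRANSCRIPTS.get((transcript_type, video_id))
    if not words and fallback is not None:
        return get_words_sketch(video_id, transcript_type=fallback, fallback=None)
    return words or []

print(get_words_sketch('dQw4w9WgXcQ'))  # falls back to the manual transcript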
 
src/segment.py CHANGED
@@ -5,27 +5,19 @@ from dataclasses import dataclass, field
 
 @dataclass
 class SegmentationArguments:
-    pause_threshold: int = field(default=2, metadata={
+    pause_threshold: int = field(default=2.5, metadata={
         'help': 'When the time between words is greater than pause threshold, force into a new segment'})
 
 
-# WORDS TO ALWAYS HAVE ON THEIR OWN
-# always_split_re = re.compile(r'\[\w+\]')
-# e.g., [Laughter], [Applause], [Music]
-always_split = [
-    CustomTokens.MUSIC.value,
-    CustomTokens.APPLAUSE.value,
-    CustomTokens.LAUGHTER.value
-]
-
-
 def get_overlapping_chunks_of_tokens(tokens, size, overlap):
     for i in range(0, len(tokens), size-overlap+1):
         yield tokens[i:i+size]
 
 
-# Generate up to max_tokens - SAFETY_TOKENS
-SAFETY_TOKENS = 12
+# Generate up to SAFETY_TOKENS_PERCENTAGE*max_tokens tokens
+MIN_SAFETY_TOKENS = 8
+SAFETY_TOKENS_PERCENTAGE = 0.9765625
+# e.g. 512 -> 500, 768 -> 750
 
 
 # TODO play around with this?
@@ -34,15 +26,9 @@ OVERLAP_TOKEN_PERCENTAGE = 0.5 # 0.25
 
 def add_labels_to_words(words, sponsor_segments):
 
-    # TODO binary search
-    for word in words:
-        word['category'] = None
-        for sponsor_segment in sponsor_segments:
-            if sponsor_segment['start'] <= word['start'] <= sponsor_segment['end']:
-                word['category'] = sponsor_segment['category']
-
-    # TODO use extract_segment with mapping function?
-    # TODO remove sponsor segments that contain mostly empty space?
+    for sponsor_segment in sponsor_segments:
+        for w in extract_segment(words, sponsor_segment['start'], sponsor_segment['end']):
+            w['category'] = sponsor_segment['category']
 
     return words
 
@@ -69,84 +55,86 @@ def generate_segments(words, tokenizer, segmentation_args):
 
     for index, word in enumerate(words):
         # Get length of tokenized word
-        cleaned = preprocess.clean_text(word['text'])
+        word['cleaned'] = preprocess.clean_text(word['text'])
         word['num_tokens'] = len(
-            tokenizer(cleaned, add_special_tokens=False, truncation=True).input_ids)
+            tokenizer(word['cleaned'], add_special_tokens=False, truncation=True).input_ids)
 
-        add_new_segment = index == 0
-        if not add_new_segment:
-
-            if word['text'] in always_split or words[index-1]['text'] in always_split:
-                add_new_segment = True
-
-            # Pause too small, do not split
-            elif word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
-                add_new_segment = True
-
-        if add_new_segment:  # New segment
+        # Add new segment
+        if index == 0 or word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
             first_pass_segments.append([word])
 
         else:  # Add to current segment
             first_pass_segments[-1].append(word)
 
-    max_q_size = tokenizer.model_max_length - SAFETY_TOKENS
+    max_q_size = round(SAFETY_TOKENS_PERCENTAGE * tokenizer.model_max_length)
 
     buffer_size = OVERLAP_TOKEN_PERCENTAGE*max_q_size  # tokenizer.model_max_length
 
     # In second pass, we split those segments if too big
     second_pass_segments = []
+
     for segment in first_pass_segments:
         current_segment_num_tokens = 0
         current_segment = []
+
         for word in segment:
-            new_seg = current_segment_num_tokens + word['num_tokens'] >= max_q_size
+            new_seg = current_segment_num_tokens + \
+                word['num_tokens'] >= max_q_size
             if new_seg:
                 # Adding this token would make it have too many tokens
                 # We save this batch and create new
-                second_pass_segments.append(current_segment.copy())
+                second_pass_segments.append(current_segment)
 
             # Add tokens to current segment
             current_segment.append(word)
             current_segment_num_tokens += word['num_tokens']
 
-            if new_seg:
-                # Just created a new segment, so we remove until we only have buffer_size tokens
-                while current_segment_num_tokens > buffer_size and current_segment:
-                    first_word = current_segment.pop(0)
-                    current_segment_num_tokens -= first_word['num_tokens']
+            if not new_seg:
+                continue
+
+            # Just created a new segment, so we remove until we only have buffer_size tokens
+            last_index = 0
+            while current_segment_num_tokens > buffer_size and current_segment:
+                current_segment_num_tokens -= current_segment[last_index]['num_tokens']
+                last_index += 1
+
+            current_segment = current_segment[last_index:]
 
        if current_segment:  # Add remaining segment
-            second_pass_segments.append(current_segment.copy())
+            second_pass_segments.append(current_segment)
 
     # Cleaning up, delete 'num_tokens' from each word
-    for segment in second_pass_segments:
-        for word in segment:
-            word.pop('num_tokens', None)
-
+    # for segment in second_pass_segments:
+    for word in words:
+        word.pop('num_tokens', None)
+
     return second_pass_segments
 
 
 def extract_segment(words, start, end, map_function=None):
     """Extracts all words with time in [start, end]"""
 
     a = binary_search(words, 0, len(words), start, True)
-    b = min(binary_search(words, 0, len(words), end , False) + 1, len(words))
+    b = min(binary_search(words, 0, len(words), end, False) + 1, len(words))
 
     to_transform = map_function is not None and callable(map_function)
 
     return [
         map_function(words[i]) if to_transform else words[i] for i in range(a, b)
     ]
 
-# Binary search to get first index of word whose start/end time is greater/less than some value
+
 def binary_search(words, start_index, end_index, time, below):
+    """Binary search to get first index of word whose start/end time is greater/less than some value"""
     if start_index >= end_index:
         return end_index
 
-    middle_index = (start_index + end_index ) // 2
-
-    middle_time = word_start(words[middle_index]) if below else word_end(words[middle_index])
+    middle_index = (start_index + end_index) // 2
+
+    middle_time = word_start(
+        words[middle_index]) if below else word_end(words[middle_index])
 
+    # TODO if above: if time < middle_time binary_search(start, middle-1)
     if time <= middle_time:
         return binary_search(words, start_index, middle_index, time, below)
     else:
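Editor's note: the second pass of generate_segments() above now emits a segment once the running token count reaches max_q_size and then trims the front of the window down to roughly buffer_size tokens, so consecutive segments overlap. The standalone sketch below reproduces that windowing idea with toy one-token words; it is not code from this commit, and the window_words name plus the copy taken before emitting are illustrative simplifications.

# Illustrative only: same trimming idea as the second pass, with a copy emitted so
# later trimming cannot alter an already-saved window.
SAFETY_TOKENS_PERCENTAGE = 0.9765625   # e.g. 512 -> 500
OVERLAP_TOKEN_PERCENTAGE = 0.5

def window_words(words, model_max_length=512):
    max_q_size = round(SAFETY_TOKENS_PERCENTAGE * model_max_length)
    buffer_size = OVERLAP_TOKEN_PERCENTAGE * max_q_size

    segments, current, num_tokens = [], [], 0
    for word in words:
        new_seg = num_tokens + word['num_tokens'] >= max_q_size
        if new_seg:
            segments.append(list(current))  # emit a copy of the window so far
        current.append(word)
        num_tokens += word['num_tokens']
        if not new_seg:
            continue
        # Drop words from the front until at most ~buffer_size tokens remain as overlap
        last_index = 0
        while num_tokens > buffer_size and current:
            num_tokens -= current[last_index]['num_tokens']
            last_index += 1
        current = current[last_index:]
    if current:
        segments.append(current)
    return segments

print(len(window_words([{'num_tokens': 1} for _ in range(1200)])))  # a few overlapping windows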