Joshua Lochner commited on
Commit
3af0cd0
·
1 Parent(s): b3b69aa

Use `itertools.islice` instead of custom slicing

Browse files
Files changed (1) hide show
  1. src/preprocess.py +21 -36
src/preprocess.py CHANGED
@@ -302,9 +302,9 @@ class PreprocessArguments:
302
  num_jobs: int = field(
303
  default=4, metadata={'help': 'Number of transcripts to download in parallel'})
304
 
305
- overwrite: bool = field(
306
- default=True, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
307
- )
308
 
309
  do_generate: bool = field(
310
  default=False, metadata={'help': 'Generate labelled data.'}
@@ -325,8 +325,8 @@ class PreprocessArguments:
325
  valid_split: float = field(
326
  default=0.05, metadata={'help': 'Ratio of validation data. Value between 0 and 1.'})
327
 
328
- skip_videos: int = field(default=None, metadata={
329
- 'help': 'Number of videos to skip. Set this to the latest video index to append to the current file'})
330
 
331
  max_videos: int = field(default=None, metadata={
332
  'help': 'Maximum number of videos to preprocess.'})
@@ -588,7 +588,8 @@ def main():
588
  progress.set_description(f'Processing {task.args[0]}')
589
  progress.update()
590
 
591
- InterruptibleTaskPool(tasks, preprocess_args.num_jobs, callback).start()
 
592
 
593
  final_path = os.path.join(
594
  processed_args.processed_dir, processed_args.processed_file)
@@ -675,36 +676,20 @@ def main():
675
 
676
  # TODO
677
  # count_videos = 0
678
- # count_segments = 0
679
-
680
- write_mode = 'w' if preprocess_args.overwrite else 'a'
681
-
682
- get_all = preprocess_args.max_videos is None
683
-
684
- total = len(final_data) if get_all else preprocess_args.max_videos
685
 
686
- index = 0
687
  data = final_data.items()
688
- if preprocess_args.skip_videos is not None:
689
- print('Skipping first', preprocess_args.skip_videos, 'videos')
690
- data = itertools.islice(data, preprocess_args.skip_videos, None)
691
- index = preprocess_args.skip_videos
692
 
693
- if get_all:
694
- total = max(0, total - preprocess_args.skip_videos)
695
- else:
696
- total = min(len(final_data) -
697
- preprocess_args.skip_videos, total)
698
 
699
- with open(positive_file, write_mode, encoding='utf-8') as positive, \
700
- open(negative_file, write_mode, encoding='utf-8') as negative, \
701
- tqdm(total=total) as progress:
702
 
703
- for ind, (video_id, sponsor_segments) in enumerate(data):
704
- index += 1 # TODO FIX index + incrementing
 
705
 
706
- if preprocess_args.max_videos is not None and ind >= preprocess_args.max_videos:
707
- break
708
 
709
  progress.set_description(f'Processing {video_id}')
710
  progress.update()
@@ -735,22 +720,22 @@ def main():
735
  if wps < preprocess_args.min_wps:
736
  continue
737
 
738
- segment_text = ' '.join((x['text'] for x in seg))
739
- extracted_segments = extract_sponsors(seg)
740
  d = {
741
- 'video_index': index,
742
  'video_id': video_id,
743
- 'text': clean_text(segment_text),
744
  'words_per_second': round(wps, 3),
745
  }
746
 
 
747
  if extracted_segments:
748
  extracted_texts = []
749
  for s in extracted_segments:
750
- w = ' '.join([q['text'] for q in s['words']])
751
  category = s['category'].upper()
752
  extracted_texts.append(
753
- f"{START_SEGMENT_TEMPLATE.format(category)} {w} {END_SEGMENT_TEMPLATE.format(category)}")
 
754
 
755
  extracted_text = f' {CustomTokens.BETWEEN_SEGMENTS.value} '.join(
756
  extracted_texts)
 
302
  num_jobs: int = field(
303
  default=4, metadata={'help': 'Number of transcripts to download in parallel'})
304
 
305
+ # append: bool = field(
306
+ # default=False, metadata={'help': 'Append to training, testing and validation data, if present.'}
307
+ # )
308
 
309
  do_generate: bool = field(
310
  default=False, metadata={'help': 'Generate labelled data.'}
 
325
  valid_split: float = field(
326
  default=0.05, metadata={'help': 'Ratio of validation data. Value between 0 and 1.'})
327
 
328
+ start_index: int = field(default=None, metadata={
329
+ 'help': 'Video to start at.'})
330
 
331
  max_videos: int = field(default=None, metadata={
332
  'help': 'Maximum number of videos to preprocess.'})
 
588
  progress.set_description(f'Processing {task.args[0]}')
589
  progress.update()
590
 
591
+ InterruptibleTaskPool(
592
+ tasks, preprocess_args.num_jobs, callback).start()
593
 
594
  final_path = os.path.join(
595
  processed_args.processed_dir, processed_args.processed_file)
 
676
 
677
  # TODO
678
  # count_videos = 0
679
+ # count_segments = 0
 
 
 
 
 
 
680
 
 
681
  data = final_data.items()
 
 
 
 
682
 
683
+ start_index = preprocess_args.start_index or 0
684
+ end_index = (preprocess_args.max_videos or len(data)) + start_index
 
 
 
685
 
686
+ data = list(itertools.islice(data, start_index, end_index))
 
 
687
 
688
+ with open(positive_file, 'a', encoding='utf-8') as positive, \
689
+ open(negative_file, 'a', encoding='utf-8') as negative, \
690
+ tqdm(data) as progress:
691
 
692
+ for offset, (video_id, sponsor_segments) in enumerate(data):
 
693
 
694
  progress.set_description(f'Processing {video_id}')
695
  progress.update()
 
720
  if wps < preprocess_args.min_wps:
721
  continue
722
 
 
 
723
  d = {
724
+ 'video_index': offset + start_index,
725
  'video_id': video_id,
726
+ 'text': clean_text(' '.join(x['text'] for x in seg)),
727
  'words_per_second': round(wps, 3),
728
  }
729
 
730
+ extracted_segments = extract_sponsors(seg)
731
  if extracted_segments:
732
  extracted_texts = []
733
  for s in extracted_segments:
734
+ w = ' '.join(q['text'] for q in s['words'])
735
  category = s['category'].upper()
736
  extracted_texts.append(
737
+ f'{START_SEGMENT_TEMPLATE.format(category)} {w} {END_SEGMENT_TEMPLATE.format(category)}'
738
+ )
739
 
740
  extracted_text = f' {CustomTokens.BETWEEN_SEGMENTS.value} '.join(
741
  extracted_texts)