Joshua Lochner committed on
Commit
2782b0c
1 Parent(s): 183ba5e

Fix the reduction of overlapping segments

Browse files
Files changed (1) hide show
  1. src/preprocess.py +14 -16
src/preprocess.py CHANGED
@@ -1,6 +1,4 @@
1
- from shared import CATGEGORY_OPTIONS
2
- from utils import jaccard
3
- from shared import START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE
4
  from functools import lru_cache
5
  from datetime import datetime
6
  import itertools
@@ -11,18 +9,16 @@ import segment
11
  from tqdm import tqdm
12
  from dataclasses import dataclass, field
13
  from transformers import HfArgumentParser
14
- from shared import GeneralArguments, CustomTokens
15
  import csv
16
  import re
17
  import random
18
  import logging
19
- from youtube_transcript_api import YouTubeTranscriptApi
20
- from youtube_transcript_api._errors import CouldNotRetrieveTranscript, YouTubeRequestFailed, TooManyRequests
21
  import os
22
  import json
23
  import time
24
  import requests
25
- from utils import Task, InterruptibleTaskPool
26
 
27
 
28
  def find(s, ch):
@@ -264,6 +260,9 @@ def remove_duplicate_segments(segments):
264
  if best_similar_seg not in best:
265
  best.append(best_similar_seg)
266
 
 
 
 
267
  return best
268
 
269
 
@@ -501,6 +500,7 @@ def main():
501
  processed_db_path = os.path.join(
502
  dataset_args.data_dir, dataset_args.processed_database)
503
 
 
504
  def read_db():
505
  if not preprocess_args.overwrite and os.path.exists(processed_db_path):
506
  with open(processed_db_path) as fp:
@@ -558,6 +558,11 @@ def main():
558
  # 'action': line['actionType'],
559
  })
560
 
 
 
 
 
 
561
  # We now remove whole videos from the list
562
  # Helps with obtaining "fully-labelled" videos
563
  min_date = datetime.strptime(preprocess_args.min_date, '%d/%m/%Y')
@@ -580,14 +585,7 @@ def main():
580
  # Always include segments locked by VIPs, regardless of view count
581
  del db[key]
582
 
583
- num_segments = 0
584
-
585
- # Remove duplicate sponsor segments by choosing best (most votes)
586
- print('Remove duplicate segments')
587
- for key in db:
588
- db[key] = remove_duplicate_segments(db[key])
589
- num_segments += len(db[key])
590
- print('Saved', len(db), 'videos and', num_segments, 'segments')
591
 
592
  with open(processed_db_path, 'w') as fp:
593
  json.dump(db, fp)
@@ -613,7 +611,7 @@ def main():
613
  for video_id in video_ids
614
  )
615
 
616
- print('start')
617
  with tqdm(total=len(video_ids)) as progress:
618
  def callback(task):
619
  progress.set_description(f'Processing {task.args[0]}')
 
1
+ from utils import jaccard, Task, InterruptibleTaskPool
 
 
2
  from functools import lru_cache
3
  from datetime import datetime
4
  import itertools
 
9
  from tqdm import tqdm
10
  from dataclasses import dataclass, field
11
  from transformers import HfArgumentParser
12
+ from shared import CATGEGORY_OPTIONS, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, GeneralArguments, CustomTokens
13
  import csv
14
  import re
15
  import random
16
  import logging
17
+ from youtube_transcript_api import YouTubeTranscriptApi, CouldNotRetrieveTranscript, YouTubeRequestFailed, TooManyRequests
 
18
  import os
19
  import json
20
  import time
21
  import requests
 
22
 
23
 
24
  def find(s, ch):
 
260
  if best_similar_seg not in best:
261
  best.append(best_similar_seg)
262
 
263
+ if len(segments) != len(best): # Saw some reduction... try again
264
+ return remove_duplicate_segments(best)
265
+
266
  return best
267
 
268
 
 
500
  processed_db_path = os.path.join(
501
  dataset_args.data_dir, dataset_args.processed_database)
502
 
503
+ @lru_cache(maxsize=1)
504
  def read_db():
505
  if not preprocess_args.overwrite and os.path.exists(processed_db_path):
506
  with open(processed_db_path) as fp:
 
558
  # 'action': line['actionType'],
559
  })
560
 
561
+ # Remove duplicate sponsor segments by choosing best (most votes)
562
+ print('Remove duplicate segments')
563
+ for key in db:
564
+ db[key] = remove_duplicate_segments(db[key])
565
+
566
  # We now remove whole videos from the list
567
  # Helps with obtaining "fully-labelled" videos
568
  min_date = datetime.strptime(preprocess_args.min_date, '%d/%m/%Y')
 
585
  # Always include segments locked by VIPs, regardless of view count
586
  del db[key]
587
 
588
+ print('Saved', len(db), 'videos')
 
 
 
 
 
 
 
589
 
590
  with open(processed_db_path, 'w') as fp:
591
  json.dump(db, fp)
 
611
  for video_id in video_ids
612
  )
613
 
614
+ print('Downloading transcripts')
615
  with tqdm(total=len(video_ids)) as progress:
616
  def callback(task):
617
  progress.set_description(f'Processing {task.args[0]}')