Joshua Lochner committed
Commit 4b4c9f0
1 Parent(s): 8b71088

Optimize segment generation and extraction

Files changed (1):
  src/segment.py  +36 -24
src/segment.py CHANGED
@@ -52,12 +52,14 @@ def word_end(word):
 def generate_segments(words, tokenizer, segmentation_args):
     first_pass_segments = []
 
-    for index, word in enumerate(words):
-        # Get length of tokenized word
-        word['cleaned'] = preprocess.clean_text(word['text'])
-        word['num_tokens'] = len(
-            tokenizer(word['cleaned'], add_special_tokens=False, truncation=True).input_ids)
-
+    cleaned_words_list = []
+    for w in words:
+        w['cleaned'] = preprocess.clean_text(w['text'])
+        cleaned_words_list.append(w['cleaned'])
+
+    # Get lengths of tokenized words
+    num_tokens_list = tokenizer(cleaned_words_list, add_special_tokens=False, truncation=True, return_attention_mask=False, return_length=True).length
+    for index, (word, num_tokens) in enumerate(zip(words, num_tokens_list)):
         # Add new segment
         if index == 0 or word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
             first_pass_segments.append([word])
@@ -78,7 +80,7 @@ def generate_segments(words, tokenizer, segmentation_args):
 
         for word in segment:
             new_seg = current_segment_num_tokens + \
-                word['num_tokens'] >= max_q_size
+                num_tokens >= max_q_size
             if new_seg:
                 # Adding this token would make it have too many tokens
                 # We save this batch and create new
@@ -86,7 +88,7 @@ def generate_segments(words, tokenizer, segmentation_args):
 
             # Add tokens to current segment
             current_segment.append(word)
-            current_segment_num_tokens += word['num_tokens']
+            current_segment_num_tokens += num_tokens
 
             if not new_seg:
                 continue
@@ -94,7 +96,7 @@ def generate_segments(words, tokenizer, segmentation_args):
             # Just created a new segment, so we remove until we only have buffer_size tokens
             last_index = 0
             while current_segment_num_tokens > buffer_size and current_segment:
-                current_segment_num_tokens -= current_segment[last_index]['num_tokens']
+                current_segment_num_tokens -= num_tokens_list[last_index]
                 last_index += 1
 
             current_segment = current_segment[last_index:]
@@ -102,19 +104,14 @@ def generate_segments(words, tokenizer, segmentation_args):
     if current_segment:  # Add remaining segment
         second_pass_segments.append(current_segment)
 
-    # Cleaning up, delete 'num_tokens' from each word
-    # for segment in second_pass_segments:
-    for word in words:
-        word.pop('num_tokens', None)
-
     return second_pass_segments
 
 
 def extract_segment(words, start, end, map_function=None):
     """Extracts all words with time in [start, end]"""
 
-    a = binary_search(words, 0, len(words), start, True)
-    b = min(binary_search(words, 0, len(words), end, False) + 1, len(words))
+    a = binary_search_below(words, 0, len(words) - 1, start)
+    b = min(binary_search_above(words, 0, len(words) - 1, end) + 1, len(words))
 
     to_transform = map_function is not None and callable(map_function)
 
@@ -123,18 +120,33 @@ def extract_segment(words, start, end, map_function=None):
     ]
 
 
-def binary_search(words, start_index, end_index, time, below):
-    """Binary search to get first index of word whose start/end time is greater/less than some value"""
+def avg(*items):
+    return sum(items)/len(items)
+
+
+def binary_search_below(transcript, start_index, end_index, time):
     if start_index >= end_index:
         return end_index
 
     middle_index = (start_index + end_index) // 2
+    middle = transcript[middle_index]
+    middle_time = avg(middle['start'], middle['end'])
 
-    middle_time = word_start(
-        words[middle_index]) if below else word_end(words[middle_index])
-
-    # TODO if above: if time < middle_time binary_search(start, middle-1)
     if time <= middle_time:
-        return binary_search(words, start_index, middle_index, time, below)
+        return binary_search_below(transcript, start_index, middle_index, time)
+    else:
+        return binary_search_below(transcript, middle_index + 1, end_index, time)
+
+
+def binary_search_above(transcript, start_index, end_index, time):
+    if start_index >= end_index:
+        return end_index
+
+    middle_index = (start_index + end_index + 1) // 2
+    middle = transcript[middle_index]
+    middle_time = avg(middle['start'], middle['end'])
+
+    if time >= middle_time:
+        return binary_search_above(transcript, middle_index, end_index, time)
     else:
-        return binary_search(words, middle_index + 1, end_index, time, below)
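The main optimization in generate_segments is batching: instead of one tokenizer call per word, all cleaned words go through a single call, and the token counts come back via return_length=True rather than materialising each input_ids list and calling len() on it. A minimal sketch of the before/after pattern, assuming a Hugging Face fast tokenizer (the checkpoint name is illustrative, not taken from this repo):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # assumed checkpoint
words = [{'text': 'hey'}, {'text': 'guys'}, {'text': 'welcome'}, {'text': 'back'}]

# Before: one tokenizer call per word, so per-call overhead scales with transcript length
slow = [len(tokenizer(w['text'], add_special_tokens=False,
                      truncation=True).input_ids) for w in words]

# After: one batched call; return_length=True reports the token count of every
# input, and return_attention_mask=False skips building an output we never read
fast = tokenizer([w['text'] for w in words],
                 add_special_tokens=False,
                 truncation=True,
                 return_attention_mask=False,
                 return_length=True).length

assert slow == fast  # identical counts from a single pass over the batch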
 
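The old binary_search, with its below flag and dangling TODO, is split into two specialised searches, and both now compare against a word's midpoint time, avg(start, end), instead of its raw start or end. binary_search_below returns the first index whose midpoint is at or after the query time; binary_search_above returns the last index whose midpoint is at or before it, and rounding middle_index up with (start_index + end_index + 1) // 2 is what lets that search converge when it recurses on [middle_index, end_index]. A small self-contained check of the new extract_segment bounds on a hypothetical one-word-per-second transcript (helper definitions lightly condensed from the diff above):

def avg(*items):
    return sum(items)/len(items)

def binary_search_below(transcript, start_index, end_index, time):
    # First index whose midpoint time is >= time (lower bound)
    if start_index >= end_index:
        return end_index
    middle_index = (start_index + end_index) // 2
    middle_time = avg(transcript[middle_index]['start'], transcript[middle_index]['end'])
    if time <= middle_time:
        return binary_search_below(transcript, start_index, middle_index, time)
    else:
        return binary_search_below(transcript, middle_index + 1, end_index, time)

def binary_search_above(transcript, start_index, end_index, time):
    # Last index whose midpoint time is <= time (upper bound)
    if start_index >= end_index:
        return end_index
    middle_index = (start_index + end_index + 1) // 2  # round up so [middle_index, end_index] shrinks
    middle_time = avg(transcript[middle_index]['start'], transcript[middle_index]['end'])
    if time >= middle_time:
        return binary_search_above(transcript, middle_index, end_index, time)
    else:
        return binary_search_above(transcript, start_index, middle_index - 1, time)

# Hypothetical transcript: word i spans [i, i + 1], so its midpoint is i + 0.5
words = [{'start': float(i), 'end': float(i + 1)} for i in range(10)]

a = binary_search_below(words, 0, len(words) - 1, 2.75)
b = min(binary_search_above(words, 0, len(words) - 1, 6.25) + 1, len(words))
print(a, b)  # 3 6 -> words[3:6] are exactly the words whose midpoints fall in [2.75, 6.25]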