Joshua Lochner committed • Commit 4b4c9f0 • 1 parent: 8b71088

Optimize segment generation and extraction

Files changed: src/segment.py (+36, -24)
src/segment.py
CHANGED
@@ -52,12 +52,14 @@ def word_end(word):
 def generate_segments(words, tokenizer, segmentation_args):
     first_pass_segments = []

+    cleaned_words_list = []
+    for w in words:
+        w['cleaned'] = preprocess.clean_text(w['text'])
+        cleaned_words_list.append(w['cleaned'])
+
+    # Get lengths of tokenized words
+    num_tokens_list = tokenizer(cleaned_words_list, add_special_tokens=False, truncation=True, return_attention_mask=False, return_length=True).length
+    for index, (word, num_tokens) in enumerate(zip(words, num_tokens_list)):
         # Add new segment
         if index == 0 or word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
             first_pass_segments.append([word])
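The optimization in this hunk is batching: all words are cleaned first and tokenized in a single call, and the per-input token counts are read from the length field that return_length=True produces, rather than counting tokens word by word. A minimal sketch of that pattern, assuming a Hugging Face fast tokenizer; the checkpoint name is a placeholder, not taken from this repo:

# Sketch only: batched token counting with a fast tokenizer (assumed setup).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')  # hypothetical checkpoint
texts = ['hello', 'world', 'segmentation']

# One call per word (what a per-word loop would do):
per_word = [len(tokenizer(t, add_special_tokens=False).input_ids) for t in texts]

# One batched call; with return_length=True the result carries a per-input 'length'.
batched = tokenizer(texts, add_special_tokens=False,
                    return_attention_mask=False, return_length=True).length

print(per_word, list(batched))  # the two counts should agree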
@@ -78,7 +80,7 @@ def generate_segments(words, tokenizer, segmentation_args):

         for word in segment:
             new_seg = current_segment_num_tokens + \
+                num_tokens >= max_q_size
             if new_seg:
                 # Adding this token would make it have too many tokens
                 # We save this batch and create new
@@ -86,7 +88,7 @@ def generate_segments(words, tokenizer, segmentation_args):

             # Add tokens to current segment
             current_segment.append(word)
+            current_segment_num_tokens += num_tokens

             if not new_seg:
                 continue
@@ -94,7 +96,7 @@ def generate_segments(words, tokenizer, segmentation_args):

             # Just created a new segment, so we remove until we only have buffer_size tokens
             last_index = 0
             while current_segment_num_tokens > buffer_size and current_segment:
+                current_segment_num_tokens -= num_tokens_list[last_index]
                 last_index += 1

             current_segment = current_segment[last_index:]
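The hunks above belong to the second pass: once adding the next word would reach max_q_size tokens, the current batch is closed and then trimmed from the front until at most buffer_size tokens remain, so consecutive segments overlap. A toy, self-contained illustration of that windowing, with hypothetical token counts and limits and simplified bookkeeping; it is not the repo's code:

# Toy sketch of the sliding-window segmentation idea (assumed limits).
def window(num_tokens_list, max_q_size=4, buffer_size=3):
    segments = []        # emitted segments, as lists of word indices
    current = []         # indices of words in the segment being built
    current_tokens = 0
    for i, n in enumerate(num_tokens_list):
        new_seg = current_tokens + n >= max_q_size
        if new_seg and current:
            segments.append(current.copy())   # emit the full segment
        current.append(i)
        current_tokens += n
        if not new_seg:
            continue
        # Trim from the front until only buffer_size tokens remain (the overlap).
        last_index = 0
        while current_tokens > buffer_size and current:
            current_tokens -= num_tokens_list[current[last_index]]
            last_index += 1
        current = current[last_index:]
    if current:
        segments.append(current)              # remaining partial segment
    return segments

print(window([1] * 10))
# [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6], [5, 6, 7], [6, 7, 8], [7, 8, 9]]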
@@ -102,19 +104,14 @@ def generate_segments(words, tokenizer, segmentation_args):
         if current_segment: # Add remaining segment
             second_pass_segments.append(current_segment)

-    # Cleaning up, delete 'num_tokens' from each word
-    # for segment in second_pass_segments:
-    for word in words:
-        word.pop('num_tokens', None)
-
     return second_pass_segments


 def extract_segment(words, start, end, map_function=None):
     """Extracts all words with time in [start, end]"""

+    a = binary_search_below(words, 0, len(words) - 1, start)
+    b = min(binary_search_above(words, 0, len(words) - 1, end) + 1, len(words))

     to_transform = map_function is not None and callable(map_function)
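In extract_segment, the bounds a and b are now obtained from two binary searches keyed on each word's midpoint time (the avg of 'start' and 'end'), so they bracket the words whose midpoints fall in [start, end]. The same bracketing behaviour can be seen with the standard library's bisect module instead of the recursive helpers added below; this is an illustration with a toy transcript, not the repo's code:

# Sketch: bracketing a time range over word midpoint times with bisect.
from bisect import bisect_left, bisect_right

words = [{'text': t, 'start': float(i), 'end': float(i) + 0.8}
         for i, t in enumerate(['a', 'b', 'c', 'd', 'e'])]
mid_times = [(w['start'] + w['end']) / 2 for w in words]  # 0.4, 1.4, 2.4, 3.4, 4.4

start, end = 1.0, 3.0
a = bisect_left(mid_times, start)   # first word whose midpoint is >= start
b = bisect_right(mid_times, end)    # one past the last word whose midpoint is <= end
print([w['text'] for w in words[a:b]])  # ['b', 'c']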
@@ -123,18 +120,33 @@ def extract_segment(words, start, end, map_function=None):
     ]


+def avg(*items):
+    return sum(items)/len(items)
+
+
+def binary_search_below(transcript, start_index, end_index, time):
     if start_index >= end_index:
         return end_index

     middle_index = (start_index + end_index) // 2
-    middle_time = word_start(
-        words[middle_index]) if below else word_end(words[middle_index])
-
-    # TODO if above: if time < middle_time binary_search(start, middle-1)
+    middle = transcript[middle_index]
+    middle_time = avg(middle['start'], middle['end'])
+
     if time <= middle_time:
+        return binary_search_below(transcript, start_index, middle_index, time)
+    else:
+        return binary_search_below(transcript, middle_index + 1, end_index, time)
+
+
+def binary_search_above(transcript, start_index, end_index, time):
+    if start_index >= end_index:
+        return end_index
+
+    middle_index = (start_index + end_index + 1) // 2
+    middle = transcript[middle_index]
+    middle_time = avg(middle['start'], middle['end'])
+
+    if time >= middle_time:
+        return binary_search_above(transcript, middle_index, end_index, time)
     else:
+        return binary_search_above(transcript, start_index, middle_index - 1, time)
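A note on the two new helpers: binary_search_below uses the usual floor midpoint and recurses on [start_index, middle_index] or [middle_index + 1, end_index], while binary_search_above recurses on [middle_index, end_index] when time >= middle_time, so it biases its midpoint upward with (start_index + end_index + 1) // 2. Without that bias, a two-element range would pick middle_index == start_index and recurse on the same range forever. A two-line check of the midpoints (illustration only):

lo, hi = 0, 1
print((lo + hi) // 2)      # 0: recursing on [middle, hi] would repeat [0, 1] forever
print((lo + hi + 1) // 2)  # 1: recursing on [middle, hi] gives [1, 1], which terminates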