Spaces:
Runtime error
Runtime error
File size: 4,471 Bytes
5fbdd3c 90d1f68 5fbdd3c 90d1f68 5fbdd3c 90d1f68 5fbdd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import preprocess
from shared import CustomTokens
from dataclasses import dataclass, field
@dataclass
class SegmentationArguments:
    """Options controlling how a word-level transcript is cut into segments.

    The ``help`` metadata on each field is presumably consumed by an argument
    parser elsewhere in the project (TODO confirm).
    """

    # Gap between consecutive words (word start minus previous word end) that
    # forces a new segment.
    # NOTE(review): generate_segments splits when the gap is >= this value,
    # while the help text below says "greater than".
    pause_threshold: int = field(
        default=2,
        metadata={
            'help': 'When the time between words is greater than pause threshold, force into a new segment',
        },
    )
# Tokens that must always stand alone: in generate_segments, a word equal to
# one of these starts a new segment, and so does the word right after it.
# The exact strings come from shared.CustomTokens (presumably bracketed sound
# markers such as [Music], [Applause], [Laughter] — confirm there).
always_split = [
    marker.value
    for marker in (
        CustomTokens.MUSIC,
        CustomTokens.APPLAUSE,
        CustomTokens.LAUGHTER,
    )
]
def get_overlapping_chunks_of_tokens(tokens, size, overlap):
    """Yield consecutive windows of ``size`` tokens sharing ``overlap`` tokens.

    Window ``k`` starts at index ``k * (size - overlap)``, so every token is
    covered and neighbouring windows overlap by exactly ``overlap`` tokens.
    The final window(s) may be shorter than ``size``.

    Args:
        tokens: any sliceable sequence (list, tuple, str, tensor, ...).
        size: maximum length of each window.
        overlap: number of tokens shared by consecutive windows.

    Yields:
        Slices of ``tokens``.

    Raises:
        ValueError: if ``overlap >= size`` (the window would never advance).
            Because this is a generator, the error is raised when iteration
            starts, not at call time.
    """
    step = size - overlap
    if step <= 0:
        raise ValueError(
            f'overlap ({overlap}) must be smaller than size ({size})')
    for begin in range(0, len(tokens), step):
        yield tokens[begin:begin + size]
# Tokens held back from tokenizer.model_max_length when sizing segments:
# generate_segments budgets each segment to fewer than
# model_max_length - SAFETY_TOKENS tokens.
SAFETY_TOKENS: int = 12

# When a long segment has to be split, up to this fraction of the token budget
# above is carried over from the end of one window into the start of the next,
# so consecutive windows overlap.
# TODO: tune this value (0.25 was noted as an alternative).
OVERLAP_TOKEN_PERCENTAGE: float = 0.5
def add_labels_to_words(words, sponsor_segments):
    """Tag each word with the category of the sponsor segment containing it.

    A word is inside a segment when
    ``segment['start'] <= word['start'] <= segment['end']``. If several
    segments contain the word, the one appearing last in ``sponsor_segments``
    wins. Words that fall in no segment get ``None``.

    The word dicts are modified in place (their ``'category'`` key is set) and
    the same ``words`` object is returned.
    """
    # TODO: binary search instead of a linear scan per word.
    # TODO: reuse extract_segment with a mapping function?
    # TODO: drop sponsor segments that are mostly empty space?
    for word in words:
        word['category'] = None
        containing = [
            segment['category']
            for segment in sponsor_segments
            if segment['start'] <= word['start'] <= segment['end']
        ]
        if containing:
            word['category'] = containing[-1]
    return words
def generate_labelled_segments(words, tokenizer, segmentation_args, sponsor_segments):
    """Segment ``words`` and label every word with its sponsor category.

    Returns one entry per segment produced by ``generate_segments``; each entry
    is that segment after ``add_labels_to_words`` has set the ``'category'`` of
    its words in place.
    """
    return [
        add_labels_to_words(segment, sponsor_segments)
        for segment in generate_segments(words, tokenizer, segmentation_args)
    ]
def word_start(word):
    """Return the start time stored under ``word['start']``."""
    start_time = word['start']
    return start_time
def word_end(word):
    """Return the end time of ``word``, falling back to its start time.

    Words without an ``'end'`` key are treated as instantaneous. ``'start'`` is
    looked up unconditionally, so a word lacking it raises ``KeyError`` even
    when ``'end'`` is present.
    """
    fallback = word['start']
    return word.get('end', fallback)
def generate_segments(words, tokenizer, segmentation_args):
    """Split a word-level transcript into model-sized, overlapping segments.

    Works in two passes:

    1. ``_split_on_pauses`` annotates every word with its token count and
       starts a new segment at each long pause or ``always_split`` token.
    2. ``_split_long_segments`` cuts each of those segments into windows that
       fit the budget ``tokenizer.model_max_length - SAFETY_TOKENS``. When a
       cut is needed, up to ``OVERLAP_TOKEN_PERCENTAGE`` of that budget is
       carried from the end of one window into the start of the next.

    Args:
        words: list of word dicts with ``'text'`` and ``'start'`` (``'end'`` is
            optional). Each dict is mutated in place: ``'num_tokens'`` is set.
        tokenizer: called as
            ``tokenizer(text, add_special_tokens=False, truncation=True)``; the
            result must expose ``input_ids``. Must also expose
            ``model_max_length``.
        segmentation_args: object with a ``pause_threshold`` attribute
            (e.g. ``SegmentationArguments``).

    Returns:
        A list of segments. Each segment is a non-empty list of the same word
        dicts found in ``words``. Every word appears in at least one segment.
    """
    first_pass_segments = _split_on_pauses(words, tokenizer, segmentation_args)

    max_q_size = tokenizer.model_max_length - SAFETY_TOKENS
    buffer_size = OVERLAP_TOKEN_PERCENTAGE * max_q_size
    return _split_long_segments(first_pass_segments, max_q_size, buffer_size)


def _split_on_pauses(words, tokenizer, segmentation_args):
    """First pass: count tokens per word and group words between split points."""
    segments = []
    for index, word in enumerate(words):
        # Token length of the cleaned word, without special tokens.
        cleaned = preprocess.clean_text(word['text'])
        word['num_tokens'] = len(
            tokenizer(cleaned, add_special_tokens=False, truncation=True).input_ids)

        if index == 0:
            start_new_segment = True
        else:
            previous = words[index - 1]
            start_new_segment = (
                # always_split tokens must sit in a segment of their own
                word['text'] in always_split
                or previous['text'] in always_split
                # pause long enough (>= threshold): force a split
                or word_start(word) - word_end(previous) >= segmentation_args.pause_threshold
            )

        if start_new_segment:
            segments.append([word])
        else:
            segments[-1].append(word)
    return segments


def _split_long_segments(segments, max_q_size, buffer_size):
    """Second pass: cut segments into windows under ``max_q_size`` tokens, overlapping by at most ``buffer_size`` tokens.

    A window holds fewer than ``max_q_size`` tokens, except when it consists
    of a single word that on its own exceeds the budget.
    """
    from collections import deque  # O(1) removal from the front of the window

    windows = []
    for segment in segments:
        window = deque()
        window_tokens = 0
        for word in segment:
            if window_tokens + word['num_tokens'] < max_q_size:
                # Still fits: extend the current window.
                window.append(word)
                window_tokens += word['num_tokens']
                continue

            # Adding this word would overflow: emit the current window. Skip
            # it when empty, which happens if the very first word overflows.
            if window:
                windows.append(list(window))
            window.append(word)
            window_tokens += word['num_tokens']
            # Keep only the tail (at most buffer_size tokens) as overlap for
            # the next window. Never drop the word just added; otherwise a
            # word longer than buffer_size would vanish from every window.
            while window_tokens > buffer_size and len(window) > 1:
                dropped = window.popleft()
                window_tokens -= dropped['num_tokens']

        if window:
            windows.append(list(window))
    return windows
def extract_segment(words, start, end, map_function=None):
    """Return the words that overlap the closed interval ``[start, end]``.

    A word is skipped if it ends before ``start`` (``word_end(w) < start``).
    Scanning stops at the first word that begins after ``end``, so ``words``
    must be ordered by start time. If ``map_function`` is callable, its result
    for each kept word is collected instead of the word itself. A non-callable
    ``map_function`` is ignored.

    Returns an empty list when ``start > end``.
    """
    selected = []
    if start > end:
        return selected

    transform = map_function if callable(map_function) else None
    # TODO: replace the linear scan with a binary search.
    for candidate in words:
        if word_end(candidate) < start:
            continue  # finishes before the window opens
        if word_start(candidate) > end:
            break  # input is sorted, so nothing later can overlap
        selected.append(candidate if transform is None else transform(candidate))
    return selected
|