#!/usr/bin/env python
# -*- coding: utf-8 -*-
import re


class RegexExpressions:
    split_by_dot = re.compile(r'[^.]+(?:\.\s*)?')
    split_by_semicolon = re.compile(r'[^;]+(?:\;\s*)?')
    split_by_colon = re.compile(r'[^:]+(?:\:\s*)?')
    split_by_comma = re.compile(r'[^,]+(?:\,\s*)?')

    url = re.compile(
        r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}'
        r'\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)'
    )
    domain = re.compile(r'\w+\.\w+')
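    # For illustration: split_by_dot.findall('One. Two. Three.') yields
    # ['One. ', 'Two. ', 'Three.']; each match keeps its trailing separator,
    # which SplitStrategy strips later.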

class SplitStrategy:
    def __init__(
        self,
        split_patterns,
        remove_patterns=None,
        group_splits=True,
        remove_too_short_groups=True
    ):
        if not isinstance(split_patterns, list):
            self.split_patterns = [split_patterns]
        else:
            self.split_patterns = split_patterns

        if remove_patterns is not None \
                and not isinstance(remove_patterns, list):
            self.remove_patterns = [remove_patterns]
        else:
            self.remove_patterns = remove_patterns

        self.group_splits = group_splits
        self.remove_too_short_groups = remove_too_short_groups
    def split(self, text, tokenizer, split_patterns=None):
        """Split `text` into chunks that fit the tokenizer's length limit
        (where possible)."""
        if split_patterns is None:
            if self.split_patterns is None:
                return [text]
            split_patterns = self.split_patterns

        def len_in_tokens(text_):
            return len(tokenizer.encode(text_, add_special_tokens=False))

        # Token budget = maximum sequence length minus the special tokens
        # ([CLS], [SEP], ...) that the tokenizer adds on its own.
        # Note: `tokenizer.max_len` is the attribute name used by older
        # `transformers` releases; newer releases expose it as
        # `model_max_length`.
        no_special_tokens = len(tokenizer.encode('', add_special_tokens=True))
        max_tokens = tokenizer.max_len - no_special_tokens

        # Strip unwanted fragments (e.g. URLs) before measuring or splitting.
        if self.remove_patterns is not None:
            for remove_pattern in self.remove_patterns:
                text = re.sub(remove_pattern, '', text).strip()

        # The whole text already fits: no splitting needed.
        if len_in_tokens(text) <= max_tokens:
            return [text]

        selected_splits = []
        splits = map(lambda x: x.strip(), re.findall(split_patterns[0], text))

        aggregated_splits = ''
        for split in splits:
            if len_in_tokens(split) > max_tokens:
                # Still too long: recurse with the next, finer-grained pattern
                # if one is available.
                if len(split_patterns) > 1:
                    sub_splits = self.split(
                        split, tokenizer, split_patterns[1:])
                    selected_splits.extend(sub_splits)
                else:
                    selected_splits.append(split)
            else:
                if not self.group_splits:
                    selected_splits.append(split)
                else:
                    # Greedily pack consecutive splits into one chunk while it
                    # still fits under the token budget.
                    new_aggregated_splits = \
                        f'{aggregated_splits} {split}'.strip()
                    if len_in_tokens(new_aggregated_splits) <= max_tokens:
                        aggregated_splits = new_aggregated_splits
                    else:
                        selected_splits.append(aggregated_splits)
                        aggregated_splits = split

        if aggregated_splits:
            selected_splits.append(aggregated_splits)

        # Optionally drop chunks shorter than half the model's limit; this
        # only applies when grouping is on and more than one chunk was made.
        remove_too_short_groups = len(selected_splits) > 1 \
            and self.group_splits \
            and self.remove_too_short_groups

        if not remove_too_short_groups:
            final_splits = selected_splits
        else:
            final_splits = []
            min_length = tokenizer.max_len / 2
            for split in selected_splits:
                if len_in_tokens(split) >= min_length:
                    final_splits.append(split)

        return final_splits

class SplitStrategies:
    # Two ready-made configurations: one keeps every sentence-level split as
    # its own chunk, the other greedily groups splits into larger chunks and
    # drops groups that end up too short.
    SentencesWithoutUrls = SplitStrategy(
        split_patterns=[
            RegexExpressions.split_by_dot,
            RegexExpressions.split_by_semicolon,
            RegexExpressions.split_by_colon,
            RegexExpressions.split_by_comma
        ],
        remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
        remove_too_short_groups=False,
        group_splits=False
    )

    GroupedSentencesWithoutUrls = SplitStrategy(
        split_patterns=[
            RegexExpressions.split_by_dot,
            RegexExpressions.split_by_semicolon,
            RegexExpressions.split_by_colon,
            RegexExpressions.split_by_comma
        ],
        remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
        remove_too_short_groups=True,
        group_splits=True
    )
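
# Minimal usage sketch (an illustration under stated assumptions, not part of
# the original module): drive a strategy with a Hugging Face tokenizer. The
# model name 'bert-base-uncased' is an arbitrary choice. The code above reads
# `tokenizer.max_len`, the attribute name from older `transformers` releases;
# newer releases call it `model_max_length`, so it is mirrored here if absent.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    if not hasattr(tokenizer, 'max_len'):
        tokenizer.max_len = tokenizer.model_max_length

    sample = (
        'First sentence about the topic. Second sentence follows; it has a '
        'clause, and then some more text. See https://example.com for details.'
    )

    # URLs and bare domains are removed first; the text is split further only
    # if it does not fit into the tokenizer's token budget.
    for chunk in SplitStrategies.GroupedSentencesWithoutUrls.split(
            sample, tokenizer):
        print(repr(chunk))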