# ernie_demo_toy/ernie/split_strategies.py
# Author: Jean Garcia-Gathright ("added ernie files", commit a02c788)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import re
class RegexExpressions:
    """Precompiled regular expressions shared by the splitting strategies.

    The ``split_by_*`` patterns chunk a text on a punctuation delimiter
    while keeping the delimiter (and any trailing whitespace) attached to
    the preceding chunk, so ``findall`` reproduces the full text as a
    sequence of pieces. ``url`` and ``domain`` match web addresses so they
    can be stripped before splitting.
    """

    # One chunk per sentence-like unit, delimiter kept with the chunk.
    split_by_dot = re.compile(r'[^.]+(?:\.\s*)?')
    split_by_semicolon = re.compile(r'[^;]+(?:\;\s*)?')
    split_by_colon = re.compile(r'[^:]+(?:\:\s*)?')
    split_by_comma = re.compile(r'[^,]+(?:\,\s*)?')

    # Full http(s) URL, optionally with a path/query tail.
    url = re.compile(
        r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}'
        r'\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)'
    )
    # Bare "word.word" domain, e.g. "example.com".
    domain = re.compile(r'\w+\.\w+')
class SplitStrategy:
    """Split a text into chunks that fit within a tokenizer's length limit.

    Patterns in ``split_patterns`` are tried in order: a chunk that is
    still too long after splitting with pattern ``i`` is recursively
    re-split with pattern ``i + 1``. Optionally, adjacent chunks are
    re-grouped greedily up to the token limit, and groups shorter than
    half the limit are dropped.
    """

    def __init__(
        self,
        split_patterns,
        remove_patterns=None,
        group_splits=True,
        remove_too_short_groups=True
    ):
        """
        :param split_patterns: compiled regex (or list of regexes, tried
            in order) whose ``findall`` partitions a text; ``None``
            disables splitting entirely.
        :param remove_patterns: regex or list of regexes whose matches
            are deleted from the text before splitting (e.g. URLs).
        :param group_splits: greedily merge consecutive chunks while they
            still fit the token limit.
        :param remove_too_short_groups: drop grouped chunks shorter than
            half the tokenizer limit (only when grouping produced more
            than one chunk).
        """
        # BUG FIX: a plain ``None`` used to be wrapped into ``[None]``,
        # which made the ``split_patterns is None`` guard in split() dead
        # code and crashed ``re.findall(None, text)`` on long inputs.
        if split_patterns is None or isinstance(split_patterns, list):
            self.split_patterns = split_patterns
        else:
            self.split_patterns = [split_patterns]

        if remove_patterns is not None \
                and not isinstance(remove_patterns, list):
            self.remove_patterns = [remove_patterns]
        else:
            self.remove_patterns = remove_patterns

        self.group_splits = group_splits
        self.remove_too_short_groups = remove_too_short_groups

    def split(self, text, tokenizer, split_patterns=None):
        """Return a list of chunks of ``text`` that each fit the limit.

        :param text: the input text.
        :param tokenizer: object exposing ``encode(text,
            add_special_tokens=...)`` and ``max_len`` (HuggingFace-style;
            NOTE(review): ``max_len`` is deprecated in newer transformers
            in favour of ``model_max_length`` — confirm the pinned
            version).
        :param split_patterns: override for the instance's patterns,
            used internally for the recursive fallback.
        """
        if split_patterns is None:
            if self.split_patterns is None:
                return [text]
            split_patterns = self.split_patterns

        def len_in_tokens(text_):
            # Token count without the special tokens the model adds.
            return len(tokenizer.encode(text_, add_special_tokens=False))

        # Budget = model limit minus the special tokens added per sequence.
        no_special_tokens = len(tokenizer.encode('', add_special_tokens=True))
        max_tokens = tokenizer.max_len - no_special_tokens

        if self.remove_patterns is not None:
            for remove_pattern in self.remove_patterns:
                text = re.sub(remove_pattern, '', text).strip()

        # Fast path: the whole (cleaned) text already fits.
        if len_in_tokens(text) <= max_tokens:
            return [text]

        selected_splits = []
        splits = map(lambda x: x.strip(), re.findall(split_patterns[0], text))

        aggregated_splits = ''
        for split in splits:
            if len_in_tokens(split) > max_tokens:
                # Chunk still too long: recurse with the next, finer
                # pattern; if none remain, keep the oversized chunk as-is.
                if len(split_patterns) > 1:
                    sub_splits = self.split(
                        split, tokenizer, split_patterns[1:])
                    selected_splits.extend(sub_splits)
                else:
                    selected_splits.append(split)
            else:
                if not self.group_splits:
                    selected_splits.append(split)
                else:
                    # Greedily pack consecutive chunks up to the limit.
                    new_aggregated_splits = \
                        f'{aggregated_splits} {split}'.strip()
                    if len_in_tokens(new_aggregated_splits) <= max_tokens:
                        aggregated_splits = new_aggregated_splits
                    else:
                        selected_splits.append(aggregated_splits)
                        aggregated_splits = split

        if aggregated_splits:
            selected_splits.append(aggregated_splits)

        # Only prune short groups when grouping actually produced
        # multiple chunks (never discard the sole remaining chunk).
        remove_too_short_groups = len(selected_splits) > 1 \
            and self.group_splits \
            and self.remove_too_short_groups

        if not remove_too_short_groups:
            final_splits = selected_splits
        else:
            final_splits = []
            min_length = tokenizer.max_len / 2
            for split in selected_splits:
                if len_in_tokens(split) >= min_length:
                    final_splits.append(split)

        return final_splits
class SplitStrategies:
    """Namespace of ready-made :class:`SplitStrategy` instances.

    Both strategies split on sentence punctuation (dot, then semicolon,
    colon and comma as progressively finer fallbacks) after stripping
    URLs and bare domains from the text.
    """

    # One chunk per sentence-level unit; no regrouping or pruning.
    SentencesWithoutUrls = SplitStrategy(
        split_patterns=[
            RegexExpressions.split_by_dot,
            RegexExpressions.split_by_semicolon,
            RegexExpressions.split_by_colon,
            RegexExpressions.split_by_comma,
        ],
        remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
        group_splits=False,
        remove_too_short_groups=False,
    )

    # Sentences greedily merged up to the token limit; short leftover
    # groups are discarded.
    GroupedSentencesWithoutUrls = SplitStrategy(
        split_patterns=[
            RegexExpressions.split_by_dot,
            RegexExpressions.split_by_semicolon,
            RegexExpressions.split_by_colon,
            RegexExpressions.split_by_comma,
        ],
        remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
        group_splits=True,
        remove_too_short_groups=True,
    )