|
import re
from abc import ABC, abstractmethod
from collections import Counter
from typing import Union

import inflect
import nltk
from flair.data import Sentence
from flair.models import SequenceTagger
|
|
|
# Public transforms exported by ``from <module> import *``.
__all__ = [
    "DropFileExtensions",
    "DropNonAlpha",
    "DropShortWords",
    "DropSpecialCharacters",
    "DropTokens",
    "DropURLs",
    "DropWords",
    "FilterPOS",
    "FrequencyMinWordCount",
    "RemoveDuplicates",
    "ReplaceSeparators",
    "ToLowercase",
    "ToSingular",
]
|
|
|
|
|
class BaseTextTransform(ABC):
    """Abstract base class for callable text-to-text transforms.

    Concrete subclasses implement ``__call__`` to map one input string to one output string.
    """

    @abstractmethod
    def __call__(self, text: str) -> str:
        """Transform ``text`` and return the result.

        Raises:
            NotImplementedError: Always; concrete subclasses must override this method.
        """
        raise NotImplementedError

    def __repr__(self) -> str:
        # Parameterless transforms render as "ClassName()".
        return "{}()".format(type(self).__name__)
|
|
|
|
|
class DropFileExtensions(BaseTextTransform):
    """Strip dot-prefixed suffixes, such as file extensions, from the input text.

    Every period that is immediately followed by one or more word characters is deleted together
    with those characters, wherever it occurs: "photo.jpg" -> "photo", "a.tar.gz" -> "a", but also
    "3.14" -> "3".
    """

    _extension_re = re.compile(r"\.\w+")

    def __call__(self, text: str) -> str:
        """
        Args:
            text (str): Text to remove file extensions from.

        Returns:
            str: ``text`` with every ``.<word characters>`` run removed.
        """
        return self._extension_re.sub("", text)
|
|
|
|
|
class DropNonAlpha(BaseTextTransform):
    """Remove every character that is neither an ASCII letter nor whitespace.

    This works character by character, not word by word: "cat2" becomes "cat" and non-ASCII
    letters are dropped too ("café" becomes "caf").
    """

    _disallowed_re = re.compile(r"[^a-zA-Z\s]")

    def __call__(self, text: str) -> str:
        """
        Args:
            text (str): Text to remove non-alphabetic characters from.

        Returns:
            str: ``text`` containing only ASCII letters and whitespace.
        """
        return self._disallowed_re.sub("", text)
|
|
|
|
|
class DropShortWords(BaseTextTransform):
    """Remove whitespace-separated words shorter than ``min_length`` characters.

    The surviving words are re-joined with single spaces, so runs of whitespace are collapsed.

    Args:
        min_length (int): Minimum length of words to keep.
    """

    def __init__(self, min_length: int) -> None:
        super().__init__()
        self.min_length = min_length

    def __call__(self, text: str) -> str:
        """
        Args:
            text (str): Text to remove short words from.

        Returns:
            str: Words of length >= ``min_length``, separated by single spaces.
        """
        long_enough = filter(lambda token: len(token) >= self.min_length, text.split())
        return " ".join(long_enough)

    def __repr__(self) -> str:
        return "{}(min_length={})".format(type(self).__name__, self.min_length)
|
|
|
|
|
class DropSpecialCharacters(BaseTextTransform):
    """Remove special characters from the input text.

    A character is "special" when it is not a word character, whitespace, hyphen, period,
    apostrophe, or ampersand; all such characters are deleted.
    """

    _special_re = re.compile(r"[^\w\s\-\.\'\&]")

    def __call__(self, text: str) -> str:
        """
        Args:
            text (str): Text to remove special characters from.

        Returns:
            str: ``text`` with every special character removed.
        """
        return self._special_re.sub("", text)
|
|
|
|
|
class DropTokens(BaseTextTransform):
    """Remove angle-bracket tokens such as ``<token>`` from the input text.

    A token is ``<``, one or more characters other than ``>`` (whitespace included), and ``>``.
    """

    _token_re = re.compile(r"<[^>]+>")

    def __call__(self, text: str) -> str:
        """
        Args:
            text (str): Text to remove tokens from.

        Returns:
            str: ``text`` with every ``<...>`` token deleted.
        """
        return self._token_re.sub("", text)
|
|
|
|
|
class DropURLs(BaseTextTransform):
    """Remove URLs from the input text.

    Any run of non-whitespace characters starting with "http" is deleted. This covers
    "https://..." too, and also fires when "http" appears mid-word. Scheme-less addresses such as
    "www.example.com" are not removed.
    """

    _url_re = re.compile(r"http\S+")

    def __call__(self, text: str) -> str:
        """
        Args:
            text (str): Text to remove URLs from.

        Returns:
            str: ``text`` with every "http..." run removed.
        """
        return self._url_re.sub("", text)
|
|
|
|
|
class DropWords(BaseTextTransform):
    """Remove a fixed set of words from the input text.

    Matching is case-insensitive and only whole words are removed. A match must not be directly
    preceded or followed by another word character, so "logo" does not match inside "logotype".
    Each word is matched in both its given form and its plural form, as produced by
    ``inflect.engine().plural_noun`` (e.g. "image" also matches "images").

    Words are matched literally: regex metacharacters in ``words`` (e.g. "c++") are escaped.
    Removed words leave their surrounding whitespace in place, so the output may contain runs
    of spaces.

    Args:
        words (list[str]): Words to remove. Empty strings are ignored.
    """

    def __init__(self, words: list[str]) -> None:
        super().__init__()
        self.words = words

        plural_noun = inflect.engine().plural_noun
        forms = []
        for word in words:
            if not word:
                # An empty alternative could only ever match empty strings; skip it.
                continue
            forms.append(word)
            plural = plural_noun(word)
            if plural:
                forms.append(plural)

        # dict.fromkeys de-duplicates while keeping a deterministic order for pattern/repr.
        alternatives = "|".join(re.escape(form) for form in dict.fromkeys(forms))
        # Lookarounds instead of \b so that words starting/ending with non-word characters
        # (e.g. "c++") are still matched as whole words; for ordinary words this equals \b...\b.
        self.pattern = r"(?<!\w)(?:{})(?!\w)".format(alternatives)
        self._regex = re.compile(self.pattern, flags=re.IGNORECASE)

    def __call__(self, text: str) -> str:
        """
        Args:
            text (str): Text to remove words from.

        Returns:
            str: ``text`` with every (case-insensitive) occurrence of the words deleted.
        """
        return self._regex.sub("", text)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(pattern={self.pattern})"
|
|
|
|
|
class FilterPOS(BaseTextTransform):
    """Filter words by their part-of-speech (POS) tags.

    NOTE: for backward compatibility, the two engines interpret ``tags`` differently unless
    ``keep`` is given explicitly:

    * ``engine="nltk"`` REMOVES the words whose tag is in ``tags``;
    * ``engine="flair"`` KEEPS only the words whose tag is in ``tags``.

    Pass ``keep=True`` (keep only matching words) or ``keep=False`` (remove matching words) to get
    the same semantics with either engine.

    Args:
        tags (list): POS tags to match against.
        engine (str): POS tagger to use. Must be one of "nltk" or "flair". Defaults to "nltk".
        keep (bool, optional): If True, keep only words whose tag is in ``tags``; if False, remove
            them. Defaults to None, which preserves the engine-specific behavior described above.

    Raises:
        ValueError: If ``engine`` is not one of the supported engines.
    """

    _ENGINES = ("nltk", "flair")

    def __init__(self, tags: list, engine: str = "nltk", keep: Union[bool, None] = None) -> None:
        super().__init__()
        # Previously an unknown engine left ``self.tagger`` unset and made the transform a
        # silent no-op; fail fast instead, as the documented contract requires.
        if engine not in self._ENGINES:
            raise ValueError(f"Unknown POS engine {engine!r}; expected one of {self._ENGINES}.")

        self.tags = tags
        self.engine = engine
        self.keep = keep

        if engine == "nltk":
            # NOTE(review): newer NLTK releases look up "averaged_perceptron_tagger_eng" and
            # "punkt_tab" instead of these resource names — confirm against the pinned version.
            nltk.download("averaged_perceptron_tagger", quiet=True)
            nltk.download("punkt", quiet=True)
            self.tagger = lambda x: nltk.pos_tag(nltk.word_tokenize(x))
        else:
            self.tagger = SequenceTagger.load("flair/pos-english-fast").predict

    def __call__(self, text: str) -> str:
        """
        Args:
            text (str): Text whose words are filtered by POS tag.

        Returns:
            str: The retained words (as tokenized by the engine) joined with single spaces.
        """
        if self.engine == "nltk":
            word_tags = self.tagger(text)
        else:
            sentence = Sentence(text)
            self.tagger(sentence)  # annotates the sentence's tokens in place
            word_tags = [(token.text, token.tag) for token in sentence.tokens]

        keep = self.keep
        if keep is None:
            # Legacy semantics: flair keeps matching tags, nltk removes them.
            keep = self.engine == "flair"

        return " ".join(word for word, tag in word_tags if (tag in self.tags) == keep)

    def __repr__(self) -> str:
        extra = "" if self.keep is None else f", keep={self.keep}"
        return f"{self.__class__.__name__}(tags={self.tags}, engine={self.engine}{extra})"
|
|
|
|
|
class FrequencyMinWordCount(BaseTextTransform):
    """Keep only words that occur at least ``min_count`` times in the input text.

    If no word reaches the threshold, the threshold is lowered to the count of the most frequent
    word, so the most frequent word(s) always survive. The kept words appear in their original
    order (including repeats), joined with single spaces. When ``min_count <= 1`` the text is
    returned unchanged.

    Args:
        min_count (int): Minimum number of occurrences of a word to keep.
    """

    def __init__(self, min_count: int) -> None:
        super().__init__()
        self.min_count = min_count

    def __call__(self, text: str) -> str:
        """
        Args:
            text (str): Text to remove infrequent words from.

        Returns:
            str: The words whose frequency reaches the (possibly lowered) threshold.
        """
        if self.min_count <= 1:
            return text

        words = text.split()
        # Counter counts in one pass; the old ``words.count`` per word was O(n^2).
        word_counts = Counter(words)
        max_word_count = max(word_counts.values(), default=0)
        # Never demand more occurrences than the most frequent word has.
        min_count = min(self.min_count, max_word_count)

        return " ".join(word for word in words if word_counts[word] >= min_count)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(min_count={self.min_count})"
|
|
|
|
|
class ReplaceSeparators(BaseTextTransform):
    """Replace every underscore and hyphen with a single space."""

    _separator_re = re.compile(r"[_\-]")

    def __call__(self, text: str) -> str:
        """
        Args:
            text (str): Text to replace separators in.

        Returns:
            str: ``text`` with each "_" and "-" turned into " ".
        """
        return self._separator_re.sub(" ", text)

    def __repr__(self) -> str:
        return type(self).__name__ + "()"
|
|
|
|
|
class RemoveDuplicates(BaseTextTransform):
    """Remove duplicate words from the input text.

    The first occurrence of each whitespace-separated word is kept and the original word order is
    preserved, so the output is deterministic. The previous ``set``-based version produced an
    order that depended on string-hash randomization. Words are joined with single spaces.
    """

    def __call__(self, text: str) -> str:
        """
        Args:
            text (str): Text to remove duplicate words from.

        Returns:
            str: Unique words in order of first appearance.
        """
        # dict preserves insertion order, unlike set.
        return " ".join(dict.fromkeys(text.split()))
|
|
|
|
|
class TextCompose:
    """Chain several text transforms into a single callable.

    Unlike ``torchvision.transforms.Compose``, which works on PIL images or tensors, this works
    on strings. A list of input strings is first joined with single spaces. Each transform is then
    applied in order, and the final string is split on whitespace into a list of words.

    Args:
        transforms (list): Transforms to apply, in order.
    """

    def __init__(self, transforms: list[BaseTextTransform]) -> None:
        self.transforms = transforms

    def __call__(self, text: Union[str, list[str]]) -> list[str]:
        """
        Args:
            text (Union[str, list[str]]): A string, or a list of strings to join with spaces.

        Returns:
            list[str]: The whitespace-separated words of the transformed text.
        """
        result = " ".join(text) if isinstance(text, list) else text
        for transform in self.transforms:
            result = transform(result)
        return result.split()

    def __repr__(self) -> str:
        lines = [f"{self.__class__.__name__}("]
        lines.extend(f" {t}" for t in self.transforms)
        lines.append(")")
        return "\n".join(lines)
|
|
|
|
|
class ToLowercase(BaseTextTransform):
    """Lowercase the whole input text via ``str.lower``."""

    def __call__(self, text: str) -> str:
        """
        Args:
            text (str): Text to convert to lowercase.

        Returns:
            str: The lowercased text.
        """
        lowered = text.lower()
        return lowered
|
|
|
|
|
class ToSingular(BaseTextTransform):
    """Convert plural words to their singular form using ``inflect``.

    Only whitespace-separated tokens ending in a lowercase "s" are candidates. Tokens ending in
    "ss", "us", "is", "ies" or "oes" are left untouched. When ``singular_noun`` returns a falsy
    value, the token is also kept as-is. Output tokens are joined with single spaces.
    """

    # Endings that are skipped even though the token ends in "s".
    _SKIPPED_SUFFIXES = ("ss", "us", "is", "ies", "oes")

    def __init__(self) -> None:
        super().__init__()
        self.transform = inflect.engine().singular_noun

    def _singularize(self, word: str) -> str:
        """Return the singular of ``word``, or ``word`` itself when it is not a candidate."""
        if not word.endswith("s") or word.endswith(self._SKIPPED_SUFFIXES):
            return word
        return self.transform(word) or word

    def __call__(self, text: str) -> str:
        """
        Args:
            text (str): Text to convert to singular form.

        Returns:
            str: The text with candidate plural tokens singularized.
        """
        return " ".join(self._singularize(word) for word in text.split())

    def __repr__(self) -> str:
        return type(self).__name__ + "()"
|
|
|
|
|
def default_vocabulary_transforms() -> TextCompose:
    """Build the default pipeline of text transforms used for vocabulary preprocessing.

    Building the pipeline instantiates ``FilterPOS(engine="flair")``, which calls
    ``SequenceTagger.load("flair/pos-english-fast")``. This presumably downloads the model on
    first use; confirm in environments without network access.

    Returns:
        TextCompose: The composed pipeline, mapping a string or list of strings to a list of words.
    """
    words_to_drop = [
        "image",
        "photo",
        "picture",
        "thumbnail",
        "logo",
        "symbol",
        "clipart",
        "portrait",
        "painting",
        "illustration",
        "icon",
        "profile",
    ]
    # Penn Treebank tags for nouns, adjectives, gerunds/present participles and past participles.
    # With engine="flair", FilterPOS keeps only the words carrying one of these tags.
    pos_tags = ["NN", "NNS", "NNP", "NNPS", "JJ", "JJR", "JJS", "VBG", "VBN"]

    # NOTE(review): DropShortWords runs before DropNonAlpha, so a token that only becomes short
    # once non-letters are stripped (e.g. "a1b" -> "ab") is kept — confirm this order is intended.
    return TextCompose(
        [
            DropTokens(),
            DropURLs(),
            DropSpecialCharacters(),
            DropFileExtensions(),
            ReplaceSeparators(),
            DropShortWords(min_length=3),
            DropNonAlpha(),
            ToLowercase(),
            ToSingular(),
            DropWords(words=words_to_drop),
            FrequencyMinWordCount(min_count=2),
            FilterPOS(tags=pos_tags, engine="flair"),
            RemoveDuplicates(),
        ]
    )
|
|