# Non-code residue captured with this file (presumably from a web page),
# kept as comments:
# Spaces:
# Runtime error
# Runtime error
import pathlib
import re
from collections import Counter

import fasttext
import kenlm
import numpy as np
import sentencepiece

from badwords import badwords
from languages_id import langs_id
from normalization import normalization
from parameters_filtering import parameters_filtering
from stopwords import stopwords
class LoadParameters:
    """Load the filtering parameters, word lists and models of a language.

    A language is identified by ``lang_dataset_id``, which is matched against
    the ``dataset_id`` column of ``langs_id`` (presumably a pandas DataFrame,
    given the ``.loc`` / ``.iloc`` usage). For every resource, a falsy value
    in the matching ``langs_id`` column means "not available for this
    language" and the loader returns ``None``. If ``lang_dataset_id`` has no
    row in ``langs_id``, the resource loaders raise ``IndexError`` (from
    ``.iloc[0]`` on an empty selection).

    All members are static; the class only groups the loaders.
    """

    @staticmethod
    def _lang_resource_id(lang_dataset_id, column):
        """Return the ``column`` value of the ``langs_id`` row for ``lang_dataset_id``."""
        return langs_id.loc[langs_id["dataset_id"] == lang_dataset_id, column].iloc[0]

    @staticmethod
    def load_parameters(lang_dataset_id):
        """Return the parameter dict of the language, or the "default" one."""
        if lang_dataset_id in parameters_filtering:
            return parameters_filtering[lang_dataset_id]
        return parameters_filtering["default"]

    @staticmethod
    def load_stopwords(lang_dataset_id):
        """Return the language's stopwords as a set, or None if it has none."""
        stopwords_lang_id = LoadParameters._lang_resource_id(
            lang_dataset_id, "stopwords_id"
        )
        if not stopwords_lang_id:
            return None
        return set(stopwords[stopwords_lang_id])

    @staticmethod
    def load_badwords(lang_dataset_id):
        """Return the language's bad words as a set, or None if it has none."""
        badwords_lang_id = LoadParameters._lang_resource_id(
            lang_dataset_id, "badwords_id"
        )
        if not badwords_lang_id:
            return None
        return set(badwords[badwords_lang_id])

    @staticmethod
    def load_model_lang_id(lang_dataset_id, path_fasttext_model):
        """Load the fastText model at ``path_fasttext_model``, or return None.

        The model is only loaded when the language has a ``fasttext_id``.
        """
        if not LoadParameters._lang_resource_id(lang_dataset_id, "fasttext_id"):
            return None
        return fasttext.load_model(path_fasttext_model)

    @staticmethod
    def load_sentencepiece_model(lang_dataset_id, path_sentencepiece_model):
        """Load the SentencePiece model at ``path_sentencepiece_model``, or return None.

        The model is only loaded when the language has a ``sentencepiece_id``.
        """
        if not LoadParameters._lang_resource_id(lang_dataset_id, "sentencepiece_id"):
            return None
        sentencepiece_model = sentencepiece.SentencePieceProcessor()
        sentencepiece_model.load(path_sentencepiece_model)
        return sentencepiece_model

    @staticmethod
    def load_kenlm_model(lang_dataset_id, path_kenlm_model):
        """Load the KenLM model at ``path_kenlm_model``, or return None.

        The model is only loaded when the language has a ``kenlm_id``.
        """
        if not LoadParameters._lang_resource_id(lang_dataset_id, "kenlm_id"):
            return None
        return kenlm.Model(path_kenlm_model)
class ModifyingDocuments:
    """Helpers that normalize, tokenize, split and rewrite documents.

    Every member is a ``staticmethod``; the class only groups functions and
    is used as ``ModifyingDocuments.<name>(...)`` throughout this module.
    """

    # Characters that ``uniform_whitespace`` maps to an ASCII space.
    # NOTE(review): the literal list previously here was mojibake (UTF-8 bytes
    # re-decoded as ISO-8859-7, e.g. "β―", "γ", "οΏΌ"). Such strings either
    # never equal a single character (so nothing was normalized) or, if the
    # invisible bytes were lost, match real letters such as Greek "β"/"γ".
    # The entries that decode unambiguously are kept (U+202F, U+3000, U+00A0,
    # U+FFFC and, presumably, U+0085). The undecodable U+20xx entries are
    # replaced by all Unicode "Zs" space separators. Confirm against the
    # upstream source if the exact original set matters.
    _WHITESPACE = (
        " ",  # U+0020 SPACE
        "\u00a0",  # NO-BREAK SPACE
        "\u1680",  # OGHAM SPACE MARK
        "\u2000",  # EN QUAD
        "\u2001",  # EM QUAD
        "\u2002",  # EN SPACE
        "\u2003",  # EM SPACE
        "\u2004",  # THREE-PER-EM SPACE
        "\u2005",  # FOUR-PER-EM SPACE
        "\u2006",  # SIX-PER-EM SPACE
        "\u2007",  # FIGURE SPACE
        "\u2008",  # PUNCTUATION SPACE
        "\u2009",  # THIN SPACE
        "\u200a",  # HAIR SPACE
        "\u202f",  # NARROW NO-BREAK SPACE
        "\u205f",  # MEDIUM MATHEMATICAL SPACE
        "\u3000",  # IDEOGRAPHIC SPACE
        "\u0085",  # NEXT LINE (NEL)
        "\ufffc",  # OBJECT REPLACEMENT CHARACTER
    )

    @staticmethod
    def remove_empty_el_from_list(list_):
        """Return a new list containing only the truthy elements of ``list_``."""
        return [el for el in list_ if el]

    @staticmethod
    def remove_non_printing_characters(document, non_printing_characters_re):
        """Delete every match of the compiled regex ``non_printing_characters_re``."""
        return non_printing_characters_re.sub("", document)

    @staticmethod
    def uniform_whitespace(document, whitespace=_WHITESPACE):
        """Replace every character of ``document`` found in ``whitespace`` by " ".

        ``whitespace`` is an iterable of single characters. The document is
        compared character by character, so multi-character entries can
        never match.
        """
        whitespace = frozenset(whitespace)
        return "".join(" " if char in whitespace else char for char in document)

    @staticmethod
    def replace_digits_with_zeros(document, digits_re):
        """Replace every match of the compiled regex ``digits_re`` with "0"."""
        return digits_re.sub("0", document)

    @staticmethod
    def replace_unicode_punctuation(document, unicode_punctuation):
        """Map each character through the ``unicode_punctuation`` dict.

        Characters without an entry are kept unchanged.
        """
        return "".join(unicode_punctuation.get(c, c) for c in document)

    @staticmethod
    def normalization(
        document,
        remove_non_printing_characters,
        strip,
        lower_case,
        uniform_whitespace,
        replace_digits_with_zeros,
        replace_unicode_punctuation,
        non_printing_characters_re=normalization["non_printing_characters_re"],
        digits_re=normalization["digits_re"],
        unicode_punctuation=normalization["unicode_punctuation"],
    ):
        """Apply the enabled normalization steps, in this fixed order.

        1. Remove non-printing characters (``non_printing_characters_re``).
        2. ``str.strip()``. An empty document is returned as-is from here on.
        3. Lower-case.
        4. Map whitespace variants to " " (``ModifyingDocuments.uniform_whitespace``).
        5. Replace digits with "0" (``digits_re``).
        6. Map Unicode punctuation (``unicode_punctuation``).

        The boolean arguments switch the corresponding step on or off. The
        regex/mapping defaults come from the module-level ``normalization``
        config.
        """
        if remove_non_printing_characters:
            document = ModifyingDocuments.remove_non_printing_characters(
                document, non_printing_characters_re
            )
        if strip:
            document = document.strip()
        if not document:
            return document
        if lower_case:
            document = document.lower()
        if uniform_whitespace:
            document = ModifyingDocuments.uniform_whitespace(document)
        if replace_digits_with_zeros:
            document = ModifyingDocuments.replace_digits_with_zeros(document, digits_re)
        if replace_unicode_punctuation:
            document = ModifyingDocuments.replace_unicode_punctuation(
                document, unicode_punctuation
            )
        return document

    @staticmethod
    def tokenization(document, sentencepiece_model, join_on_whitespace):
        """Split ``document`` into SentencePiece pieces.

        Returns the list of pieces, or a single " "-joined string when
        ``join_on_whitespace`` is true.
        """
        document_tokenized = sentencepiece_model.encode_as_pieces(document)
        if join_on_whitespace:
            document_tokenized = " ".join(document_tokenized)
        return document_tokenized

    @staticmethod
    def split_on_whitespace(
        document,
        new_line=False,
        tab=False,
    ):
        """Split on spaces (and optionally newlines / tabs), dropping empty pieces.

        Because empty pieces are dropped, runs of consecutive separators do
        not produce empty words.
        """
        sep = [" "] + new_line * ["\n"] + tab * ["\t"]
        split_document = re.split("|".join(sep), document)
        return ModifyingDocuments.remove_empty_el_from_list(split_document)

    @staticmethod
    def strip(document, strip_characters):
        """Remove leading and trailing characters that are in ``strip_characters``.

        Same result as ``document.strip(chars)``, but membership is tested
        with ``in``. That is faster when ``strip_characters`` is a large set
        (e.g. one that includes emojis).
        """
        if not document:
            return document
        beg_ind = 0
        end_ind = len(document)
        while beg_ind < end_ind and document[beg_ind] in strip_characters:
            beg_ind += 1
        while end_ind > beg_ind and document[end_ind - 1] in strip_characters:
            end_ind -= 1
        return document[beg_ind:end_ind]

    @staticmethod
    def get_words_from_document(
        document, sentencepiece_model_tok, lower_case, strip_characters
    ):
        """Return the list of words of ``document`` (for ratio computations).

        With ``sentencepiece_model_tok``, the document is fully normalized
        (lower-cased, digits -> "0", etc.) and split into SentencePiece
        pieces. Otherwise it is split on spaces, newlines and tabs. Words are
        then optionally lower-cased and stripped of ``strip_characters``, and
        empty words are dropped. The operation is not reversible.
        """
        if sentencepiece_model_tok:
            document_normalized = ModifyingDocuments.normalization(
                document=document,
                remove_non_printing_characters=True,
                strip=True,
                lower_case=True,
                uniform_whitespace=True,
                replace_digits_with_zeros=True,
                replace_unicode_punctuation=True,
            )
            words = ModifyingDocuments.tokenization(
                document_normalized, sentencepiece_model_tok, join_on_whitespace=False
            )
        else:
            words = ModifyingDocuments.split_on_whitespace(
                document, new_line=True, tab=True
            )
        if lower_case:
            words = [word.lower() for word in words]
        if strip_characters:
            words = [ModifyingDocuments.strip(word, strip_characters) for word in words]
        return ModifyingDocuments.remove_empty_el_from_list(words)

    @staticmethod
    def words_augmentation(words, group_size, join_char):
        """Return every run of ``group_size`` consecutive words joined by ``join_char``.

        Useful for languages where a "word" spans several tokens, e.g.
        Chinese (no spaces between words) or Vietnamese (spaces between
        syllables).
        """
        return [
            join_char.join(words[i : i + group_size])
            for i in range(len(words) - group_size + 1)
        ]

    @staticmethod
    def split_on_newline_tab_whitespace(document):
        """Split into lines, then tab-separated fields, then space-separated words.

        Returns a nested list: lines -> fields -> words.
        """
        sentences = document.split("\n")
        sentences = [sentence.split("\t") for sentence in sentences]
        return [
            [
                ModifyingDocuments.split_on_whitespace(subsentence)
                for subsentence in sentence
            ]
            for sentence in sentences
        ]

    @staticmethod
    def merge_on_whitespace_tab_newline(sentences):
        """Invert ``split_on_newline_tab_whitespace``.

        Empty fields and empty lines are dropped, so runs of separators
        collapse. Returns "" when nothing is left.
        """
        sentences = [
            [" ".join(subsentence) for subsentence in sentence if subsentence]
            for sentence in sentences
        ]
        sentences = ["\t".join(sentence) for sentence in sentences if sentence]
        if not sentences:
            return ""
        return "\n".join(sentences)

    @staticmethod
    def _filter_words(document, keep_word):
        """Drop words for which ``keep_word(word)`` is false.

        The newline/tab/space layout is otherwise preserved, and separators
        left empty are collapsed.
        """
        sentences = ModifyingDocuments.split_on_newline_tab_whitespace(document)
        sentences = [
            [[word for word in subsentence if keep_word(word)] for subsentence in sentence]
            for sentence in sentences
        ]
        return ModifyingDocuments.merge_on_whitespace_tab_newline(sentences)

    @staticmethod
    def should_keep_word_with_incorrect_substrings(
        word, strip_characters, incorrect_word_substrings
    ):
        """Return True if the stripped ``word`` contains none of ``incorrect_word_substrings``."""
        word = ModifyingDocuments.strip(word, strip_characters)
        return all(i_substr not in word for i_substr in incorrect_word_substrings)

    @staticmethod
    def remove_words_with_incorrect_substrings(
        document,
        strip_characters,
        incorrect_word_substrings,
    ):
        """Remove words containing any of ``incorrect_word_substrings``."""
        return ModifyingDocuments._filter_words(
            document,
            lambda word: ModifyingDocuments.should_keep_word_with_incorrect_substrings(
                word, strip_characters, incorrect_word_substrings
            ),
        )

    @staticmethod
    def should_keep_long_word(word, strip_characters, length_word_max_cutoff):
        """Return True if ``word`` is short enough to keep.

        A word longer than ``length_word_max_cutoff`` is still kept when
        stripping its leading/trailing ``strip_characters`` (e.g. punctuation
        or emojis glued to it) brings it within the cutoff. A too-long word
        made only of strip characters is dropped.
        """
        if len(word) <= length_word_max_cutoff:
            return True
        stripped = ModifyingDocuments.strip(word, strip_characters)
        return 0 < len(stripped) <= length_word_max_cutoff

    @staticmethod
    def remove_long_words(
        document,
        strip_characters,
        length_word_max_cutoff,
    ):
        """Remove words rejected by ``should_keep_long_word``."""
        return ModifyingDocuments._filter_words(
            document,
            lambda word: ModifyingDocuments.should_keep_long_word(
                word, strip_characters, length_word_max_cutoff
            ),
        )

    @staticmethod
    def modifying_documents(
        document,
        cond_uniform_whitespace,
        cond_replace_unicode_punctuation,
        cond_remove_words_with_incorrect_substrings,
        strip_characters,
        incorrect_word_substrings,
        cond_remove_long_words,
        length_word_max_cutoff,
    ):
        """Rewrite a document before filtering.

        Steps: strip; optionally uniform whitespace and Unicode-punctuation
        mapping; then optionally drop words with incorrect substrings and
        drop too-long words.
        """
        document = ModifyingDocuments.normalization(
            document=document,
            remove_non_printing_characters=False,
            strip=True,
            lower_case=False,
            uniform_whitespace=cond_uniform_whitespace,
            replace_digits_with_zeros=False,
            replace_unicode_punctuation=cond_replace_unicode_punctuation,
        )
        if cond_remove_words_with_incorrect_substrings:
            document = ModifyingDocuments.remove_words_with_incorrect_substrings(
                document,
                strip_characters,
                incorrect_word_substrings,
            )
        if cond_remove_long_words:
            document = ModifyingDocuments.remove_long_words(
                document,
                strip_characters,
                length_word_max_cutoff,
            )
        return document
class FunctionDatasetModifyingDocuments:
    """Picklable callable for ``Dataset.map`` that rewrites ``example["text"]``.

    Only the language id is needed to rebuild an instance, so ``__reduce__``
    pickles it as ``(cls, (lang_dataset_id,))`` and workers re-run
    ``__init__``.
    """

    def __init__(self, lang_dataset_id):
        self.lang_dataset_id = lang_dataset_id
        self.param = LoadParameters.load_parameters(lang_dataset_id)

    def __call__(self, example):
        """Replace ``example["text"]`` with its modified version and return ``example``."""
        document = example["text"]
        # These keyword names coincide with the keys of the parameter dict.
        options = {
            name: self.param[name]
            for name in (
                "cond_uniform_whitespace",
                "cond_replace_unicode_punctuation",
                "cond_remove_words_with_incorrect_substrings",
                "strip_characters",
                "incorrect_word_substrings",
                "cond_remove_long_words",
                "length_word_max_cutoff",
            )
        }
        example["text"] = ModifyingDocuments.modifying_documents(
            document=document, **options
        )
        return example

    def __reduce__(self):
        return type(self), (self.lang_dataset_id,)
class Filtering:
    """Document-level quality checks.

    Each ``check_*`` method returns True when the document should be KEPT.
    ``filtering`` runs the enabled checks in a fixed order and stops at the
    first failure. Every member is a ``staticmethod``; the class is only a
    namespace.
    """

    @staticmethod
    def check_number_words(
        document,
        sentencepiece_model_tok,
        strip_characters,
        number_words_min_cutoff,
        number_words_max_cutoff,
    ):
        """Keep documents whose word count lies in [min_cutoff, max_cutoff]."""
        words = ModifyingDocuments.get_words_from_document(
            document,
            sentencepiece_model_tok,
            lower_case=False,
            strip_characters=strip_characters,
        )
        return number_words_min_cutoff <= len(words) <= number_words_max_cutoff

    @staticmethod
    def compute_repetitions_ratio(document, repetitions_length):
        """Share of character n-gram occurrences taken by the most frequent n-grams.

        The method counts every overlapping character n-gram of length
        ``repetitions_length`` and keeps the ``int(sqrt(#distinct n-grams))``
        most frequent distinct ones. It returns their total count divided by
        the total number of n-grams. It returns 0 when the document is
        shorter than ``repetitions_length``.
        """
        n = repetitions_length
        freq_ngrams = Counter(document[i : i + n] for i in range(len(document) - n + 1))
        if not freq_ngrams:
            return 0
        counts = sorted(freq_ngrams.values(), reverse=True)
        num_rep_ngrams = int(np.sqrt(len(counts)))
        return sum(counts[:num_rep_ngrams]) / sum(counts)

    @staticmethod
    def check_repetitions_removal(
        document,
        repetitions_length,
        repetitions_max_cutoff,
    ):
        """Keep documents whose repetitions ratio is <= ``repetitions_max_cutoff``."""
        repetitions_ratio = Filtering.compute_repetitions_ratio(
            document, repetitions_length
        )
        return repetitions_ratio <= repetitions_max_cutoff

    @staticmethod
    def compute_special_characters_ratio(document, special_characters):
        """Fraction of the document's characters that are in ``special_characters``.

        Returns 0 for an empty document instead of dividing by zero.
        """
        if not document:
            return 0
        num_special = sum(1 for char in document if char in special_characters)
        return num_special / len(document)

    @staticmethod
    def check_special_characters(
        document,
        special_characters,
        special_characters_max_cutoff,
    ):
        """Keep documents whose special-character ratio is <= the cutoff."""
        special_characters_ratio = Filtering.compute_special_characters_ratio(
            document, special_characters
        )
        return special_characters_ratio <= special_characters_max_cutoff

    @staticmethod
    def _compute_words_ratio(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        vocabulary,
    ):
        """Shared implementation of the stopwords and badwords ratios.

        Counts the lower-cased words that belong to ``vocabulary``. When
        ``cond_words_augmentation`` is set, it also counts the joined groups
        of ``group_size`` consecutive words, for each size in
        ``words_augmentation_group_sizes``. The count is divided by the
        number of words and capped at 1.0, because augmented groups can push
        it above the word count. Returns 0 when the document has no words.
        """
        words = ModifyingDocuments.get_words_from_document(
            document,
            sentencepiece_model_tok,
            lower_case=True,
            strip_characters=strip_characters,
        )
        if not words:
            return 0
        candidates = list(words)
        if cond_words_augmentation:
            for group_size in words_augmentation_group_sizes:
                candidates.extend(
                    ModifyingDocuments.words_augmentation(
                        words, group_size, words_augmentation_join_char
                    )
                )
        ratio = sum(1 for word in candidates if word in vocabulary) / len(words)
        return min(ratio, 1.0)

    @staticmethod
    def compute_stopwords_ratio(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        stopwords,
    ):
        """Stopwords ratio of the document, in [0, 1] (see ``_compute_words_ratio``)."""
        return Filtering._compute_words_ratio(
            document,
            sentencepiece_model_tok,
            strip_characters,
            cond_words_augmentation,
            words_augmentation_group_sizes,
            words_augmentation_join_char,
            stopwords,
        )

    @staticmethod
    def check_stopwords(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        stopwords,
        stopwords_min_cutoff,
    ):
        """Keep documents whose stopwords ratio is >= ``stopwords_min_cutoff``.

        The document is always kept when ``stopwords`` is empty or None.
        """
        if not stopwords:
            return True
        stopwords_ratio = Filtering.compute_stopwords_ratio(
            document,
            sentencepiece_model_tok,
            strip_characters,
            cond_words_augmentation,
            words_augmentation_group_sizes,
            words_augmentation_join_char,
            stopwords,
        )
        return stopwords_ratio >= stopwords_min_cutoff

    @staticmethod
    def compute_badwords_ratio(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        badwords,
    ):
        """Bad-words ratio of the document, in [0, 1] (see ``_compute_words_ratio``)."""
        return Filtering._compute_words_ratio(
            document,
            sentencepiece_model_tok,
            strip_characters,
            cond_words_augmentation,
            words_augmentation_group_sizes,
            words_augmentation_join_char,
            badwords,
        )

    @staticmethod
    def check_badwords(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        badwords,
        badwords_max_cutoff,
    ):
        """Keep documents whose bad-words ratio is <= ``badwords_max_cutoff``.

        The document is always kept when ``badwords`` is empty or None.
        """
        if not badwords:
            return True
        badwords_ratio = Filtering.compute_badwords_ratio(
            document,
            sentencepiece_model_tok,
            strip_characters,
            cond_words_augmentation,
            words_augmentation_group_sizes,
            words_augmentation_join_char,
            badwords,
        )
        return badwords_ratio <= badwords_max_cutoff

    @staticmethod
    def compute_lang_id_pred_score(document, model_lang_id):
        """Return ``(predicted dataset_id, score)`` from the fastText model.

        The text is lower-cased, and newlines are replaced by spaces because
        fastText's ``predict`` handles a single line. The predicted fastText
        label is mapped back to a ``dataset_id`` through ``langs_id``. Labels
        without a mapping yield ``"unknown"``.
        """
        document = document.lower().replace("\n", " ")
        pred = model_lang_id.predict(document)
        lang_pred_fasttext_id = pred[0][0].replace("__label__", "")
        score_pred = pred[1][0]
        lang_pred_dataset_id = langs_id.loc[
            langs_id["fasttext_id"] == lang_pred_fasttext_id, "dataset_id"
        ]
        if len(lang_pred_dataset_id) > 0:
            lang_pred_dataset_id = lang_pred_dataset_id.iloc[0]
        else:
            lang_pred_dataset_id = "unknown"
        return lang_pred_dataset_id, score_pred

    @staticmethod
    def check_lang_id(
        document,
        lang_dataset_id,
        model_lang_id,
        lang_id_min_cutoff,
    ):
        """Keep documents predicted as ``lang_dataset_id`` with score >= the cutoff.

        The document is always kept when no language-id model is available.
        """
        if not model_lang_id:
            return True
        lang_pred_dataset_id, score_pred = Filtering.compute_lang_id_pred_score(
            document, model_lang_id
        )
        return (lang_pred_dataset_id == lang_dataset_id) and (
            score_pred >= lang_id_min_cutoff
        )

    @staticmethod
    def compute_perplexity_score(document, sentencepiece_model, kenlm_model):
        """Return the KenLM perplexity of the document, rounded to 0.1.

        The document is fully normalized and tokenized into " "-joined
        SentencePiece pieces. Each line is then scored by KenLM (a log10
        probability). The result is ``10 ** (-total_log_score / total_length)``,
        where every line contributes its token count plus one (presumably for
        the end-of-sentence token).
        """
        document = ModifyingDocuments.normalization(
            document=document,
            remove_non_printing_characters=True,
            strip=True,
            lower_case=True,
            uniform_whitespace=True,
            replace_digits_with_zeros=True,
            replace_unicode_punctuation=True,
        )
        document = ModifyingDocuments.tokenization(
            document, sentencepiece_model, join_on_whitespace=True
        )
        doc_log_score, doc_length = 0, 0
        for line in document.split("\n"):
            doc_log_score += kenlm_model.score(line)
            doc_length += len(line.split()) + 1
        pp_score = 10.0 ** (-doc_log_score / doc_length)
        return round(pp_score, 1)

    @staticmethod
    def check_perplexity(
        document,
        sentencepiece_model,
        kenlm_model,
        perplexity_max_cutoff,
    ):
        """Keep documents whose perplexity is <= ``perplexity_max_cutoff``.

        The document is always kept when no KenLM model is available.
        """
        if not kenlm_model:
            return True
        score = Filtering.compute_perplexity_score(
            document, sentencepiece_model, kenlm_model
        )
        return score <= perplexity_max_cutoff

    @staticmethod
    def filtering(
        document,
        cond_check_number_words,
        sentencepiece_model_tok,
        strip_characters,
        number_words_min_cutoff,
        number_words_max_cutoff,
        cond_check_repetitions_removal,
        repetitions_length,
        repetitions_max_cutoff,
        cond_check_special_characters,
        special_characters,
        special_characters_max_cutoff,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        cond_check_stopwords,
        stopwords,
        stopwords_min_cutoff,
        cond_check_badwords,
        badwords,
        badwords_max_cutoff,
        cond_check_lang_id,
        lang_dataset_id,
        model_lang_id,
        lang_id_min_cutoff,
        cond_check_perplexity,
        sentencepiece_model,
        kenlm_model,
        perplexity_max_cutoff,
    ):
        """Return True if the document passes every enabled check.

        Checks run in this order, and each is enabled by its ``cond_check_*``
        flag: number of words, repetitions, special characters, stopwords,
        bad words, language id, perplexity. Evaluation stops at the first
        failing check.
        """
        if cond_check_number_words and not Filtering.check_number_words(
            document,
            sentencepiece_model_tok,
            strip_characters,
            number_words_min_cutoff,
            number_words_max_cutoff,
        ):
            return False
        if cond_check_repetitions_removal and not Filtering.check_repetitions_removal(
            document,
            repetitions_length,
            repetitions_max_cutoff,
        ):
            return False
        if cond_check_special_characters and not Filtering.check_special_characters(
            document,
            special_characters,
            special_characters_max_cutoff,
        ):
            return False
        if cond_check_stopwords and not Filtering.check_stopwords(
            document,
            sentencepiece_model_tok,
            strip_characters,
            cond_words_augmentation,
            words_augmentation_group_sizes,
            words_augmentation_join_char,
            stopwords,
            stopwords_min_cutoff,
        ):
            return False
        if cond_check_badwords and not Filtering.check_badwords(
            document,
            sentencepiece_model_tok,
            strip_characters,
            cond_words_augmentation,
            words_augmentation_group_sizes,
            words_augmentation_join_char,
            badwords,
            badwords_max_cutoff,
        ):
            return False
        if cond_check_lang_id and not Filtering.check_lang_id(
            document,
            lang_dataset_id,
            model_lang_id,
            lang_id_min_cutoff,
        ):
            return False
        if cond_check_perplexity and not Filtering.check_perplexity(
            document,
            sentencepiece_model,
            kenlm_model,
            perplexity_max_cutoff,
        ):
            return False
        return True
class FunctionDatasetFiltering:
    """Picklable predicate for ``Dataset.filter`` (True = keep the example).

    The constructor loads the parameters, word lists and models of the
    language. ``__reduce__`` pickles only the constructor arguments, so each
    worker process reloads the models itself instead of pickling them.
    """

    def __init__(
        self,
        lang_dataset_id,
        path_fasttext_model,
        path_sentencepiece_model,
        path_kenlm_model,
    ):
        self.lang_dataset_id = lang_dataset_id
        self.path_fasttext_model = path_fasttext_model
        self.path_sentencepiece_model = path_sentencepiece_model
        self.path_kenlm_model = path_kenlm_model

        loader = LoadParameters
        self.param = loader.load_parameters(lang_dataset_id)
        self.stopwords = loader.load_stopwords(lang_dataset_id)
        self.badwords = loader.load_badwords(lang_dataset_id)
        self.model_lang_id = loader.load_model_lang_id(
            lang_dataset_id, path_fasttext_model
        )
        self.sentencepiece_model = loader.load_sentencepiece_model(
            lang_dataset_id, path_sentencepiece_model
        )
        # Word counting uses SentencePiece only when the parameters ask for it.
        if self.param["tokenization"]:
            self.sentencepiece_model_tok = self.sentencepiece_model
        else:
            self.sentencepiece_model_tok = None
        self.kenlm_model = loader.load_kenlm_model(lang_dataset_id, path_kenlm_model)

    def __call__(self, example):
        """Return True if ``example["text"]`` passes every enabled check."""
        document = example["text"]
        # (keyword of Filtering.filtering, key in self.param), in call order.
        # NOTE(review): the repetitions switch is read from the key
        # "check_repetitions_removal", which lacks the "cond_" prefix the other
        # switches use. It is kept verbatim; confirm against parameters_filtering.
        param_keywords = (
            ("cond_check_number_words", "cond_check_number_words"),
            ("strip_characters", "strip_characters"),
            ("number_words_min_cutoff", "number_words_min_cutoff"),
            ("number_words_max_cutoff", "number_words_max_cutoff"),
            ("cond_check_repetitions_removal", "check_repetitions_removal"),
            ("repetitions_length", "repetitions_length"),
            ("repetitions_max_cutoff", "repetitions_max_cutoff"),
            ("cond_check_special_characters", "cond_check_special_characters"),
            ("special_characters", "special_characters"),
            ("special_characters_max_cutoff", "special_characters_max_cutoff"),
            ("cond_words_augmentation", "cond_words_augmentation"),
            ("words_augmentation_group_sizes", "words_augmentation_group_sizes"),
            ("words_augmentation_join_char", "words_augmentation_join_char"),
            ("cond_check_stopwords", "cond_check_stopwords"),
            ("stopwords_min_cutoff", "stopwords_min_cutoff"),
            ("cond_check_badwords", "cond_check_badwords"),
            ("badwords_max_cutoff", "badwords_max_cutoff"),
            ("cond_check_lang_id", "cond_check_lang_id"),
            ("lang_id_min_cutoff", "lang_id_min_cutoff"),
            ("cond_check_perplexity", "cond_check_perplexity"),
            ("perplexity_max_cutoff", "perplexity_max_cutoff"),
        )
        settings = {keyword: self.param[key] for keyword, key in param_keywords}
        return Filtering.filtering(
            document=document,
            sentencepiece_model_tok=self.sentencepiece_model_tok,
            stopwords=self.stopwords,
            badwords=self.badwords,
            lang_dataset_id=self.lang_dataset_id,
            model_lang_id=self.model_lang_id,
            sentencepiece_model=self.sentencepiece_model,
            kenlm_model=self.kenlm_model,
            **settings,
        )

    def __reduce__(self):
        init_args = (
            self.lang_dataset_id,
            self.path_fasttext_model,
            self.path_sentencepiece_model,
            self.path_kenlm_model,
        )
        return type(self), init_args
class DatasetFiltering:
    """Modify, filter and save a dataset for one language.

    ``dataset`` is only used through ``map``, ``filter`` and ``save_to_disk``
    (presumably a Hugging Face ``datasets.Dataset``; confirm with callers).
    ``modifying_documents`` and ``filtering`` each replace ``self.ds`` with
    the transformed dataset, using ``num_proc`` worker processes.
    """

    def __init__(
        self,
        dataset,
        lang_dataset_id,
        path_fasttext_model,
        path_sentencepiece_model,
        path_kenlm_model,
        num_proc,
        path_dir_save_dataset,
    ):
        self.ds = dataset
        self.lang_dataset_id = lang_dataset_id
        self.path_fasttext_model = path_fasttext_model
        self.path_sentencepiece_model = path_sentencepiece_model
        self.path_kenlm_model = path_kenlm_model
        self.num_proc = num_proc
        self.path_dir_save_dataset = path_dir_save_dataset

    def modifying_documents(self):
        """Rewrite each example's ``"text"`` with the per-language modifications."""
        dataset_modifying_documents = FunctionDatasetModifyingDocuments(
            self.lang_dataset_id
        )
        self.ds = self.ds.map(dataset_modifying_documents, num_proc=self.num_proc)

    def filtering(self):
        """Keep only the examples that pass every enabled per-language check."""
        func_dataset_filtering = FunctionDatasetFiltering(
            self.lang_dataset_id,
            self.path_fasttext_model,
            self.path_sentencepiece_model,
            self.path_kenlm_model,
        )
        self.ds = self.ds.filter(func_dataset_filtering, num_proc=self.num_proc)

    def save_dataset(self):
        """Save ``self.ds`` to ``<path_dir_save_dataset>/<lang_dataset_id>``.

        Missing directories are created and existing ones are reused.
        """
        path_dir_save_dataset = pathlib.Path(
            self.path_dir_save_dataset, self.lang_dataset_id
        )
        # parents=True also creates self.path_dir_save_dataset, so one call is enough.
        path_dir_save_dataset.mkdir(parents=True, exist_ok=True)
        # Pass a plain str. Every save_to_disk version accepts it, while some
        # versions declare (and handle) the path argument as str only.
        self.ds.save_to_disk(str(path_dir_save_dataset))