# Spaces status: Runtime error
from transformers import MT5ForConditionalGeneration, MT5Tokenizer | |
from transformers import AutoTokenizer | |
import re | |
class PersianTextProcessor:
    """
    Clean Persian text and translate it to English with an MT5 checkpoint.

    Attributes:
        model_size (str): Size tag interpolated into the checkpoint name.
        model_name (str): Checkpoint name passed to ``from_pretrained``.
        tokenizer (MT5Tokenizer): Tokenizer loaded for ``model_name``.
        model (MT5ForConditionalGeneration): Model loaded for ``model_name``.

    Methods:
        clean_persian_text(text): Strip emojis and rewrite ``@`` / ``#``.
        run_model(input_string, **generator_args): Encode, generate, decode.
        translate_text(persian_text, **generator_args): Clean, then translate.
    """

    # Compiled once at class creation instead of on every call.
    # Each targeted emoji is removed together with any variation selectors
    # (U+FE0F) or zero-width joiners (U+200D) that directly follow it, so no
    # invisible leftovers reach the tokenizer. U+200C (ZWNJ) is not part of
    # the pattern and is therefore never removed.
    _EMOJI_RE = re.compile(
        "(?:"
        "["
        "\U0001F600-\U0001F64F"  # emoticons
        "\U0001F300-\U0001F5FF"  # symbols & pictographs
        "\U0001F680-\U0001F6FF"  # transport & map symbols
        "\U0001F1E0-\U0001F1FF"  # flags (iOS)
        "\U0001F90D\U00002764\U0001F91F"  # white heart, heavy heart, love-you gesture
        "\U0000270C"  # victory hand
        "]"
        "[\U0000FE0F\U0000200D]*"
        ")+"
    )

    def __init__(self, model_size: str = "small") -> None:
        """
        Load the tokenizer and model for the requested checkpoint size.

        Args:
            model_size (str): Size tag inserted into the checkpoint name
                ``persiannlp/mt5-<size>-parsinlu-opus-translation_fa_en``.
        """
        self.model_size = model_size
        self.model_name = f"persiannlp/mt5-{self.model_size}-parsinlu-opus-translation_fa_en"
        self.tokenizer = MT5Tokenizer.from_pretrained(self.model_name)
        self.model = MT5ForConditionalGeneration.from_pretrained(self.model_name)

    def clean_persian_text(self, text: str) -> str:
        """
        Remove emojis from the text and rewrite ``@`` and ``#``.

        Emojis in the targeted ranges are deleted along with any trailing
        variation selectors / zero-width joiners. ``@`` is deleted and ``#``
        is replaced by the literal prefix ``hashtag_``.

        Args:
            text (str): The input Persian text.

        Returns:
            str: The cleaned text.
        """
        text = self._EMOJI_RE.sub("", text)
        text = text.replace("@", "")
        text = text.replace("#", "hashtag_")
        return text

    def run_model(self, input_string, **generator_args):
        """
        Run the MT5 model on one input string.

        Args:
            input_string (str): The text to feed to the model.
            **generator_args: Extra keyword arguments forwarded unchanged to
                ``self.model.generate``.

        Returns:
            list: The decoded output strings as returned by
            ``tokenizer.batch_decode`` (a list, not a bare string).
        """
        # Encode the input string as a tensor of token ids.
        input_ids = self.tokenizer.encode(input_string, return_tensors="pt")
        # Generate the output token ids.
        res = self.model.generate(input_ids, **generator_args)
        # Decode the generated ids back to text, dropping special tokens.
        output = self.tokenizer.batch_decode(res, skip_special_tokens=True)
        return output

    def translate_text(self, persian_text, **generator_args):
        """
        Clean the given Persian text and translate it to English.

        Args:
            persian_text (str): The Persian text to translate.
            **generator_args: Optional keyword arguments forwarded to
                ``run_model`` (and from there to ``model.generate``).

        Returns:
            list: The decoded translation strings produced by ``run_model``.
        """
        text_cleaned = self.clean_persian_text(persian_text)
        return self.run_model(input_string=text_cleaned, **generator_args)