# TwitterAccounts / scripts / translation.py
# Uploaded by aus10powell — commit 33d6c4f ("Upload translation.py")
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
from transformers import AutoTokenizer
import re
class PersianTextProcessor:
    """
    Clean Persian text and translate it to English with a ParsiNLU mT5 model.

    Attributes:
        model_size (str): Size component of the checkpoint name.
        model_name (str): Full Hugging Face checkpoint id that was loaded.
        tokenizer (MT5Tokenizer): Tokenizer loaded from ``model_name``.
        model (MT5ForConditionalGeneration): Model loaded from ``model_name``.

    Methods:
        clean_persian_text(text): Strip emojis/symbols and rewrite '@' and '#'.
        run_model(input_string, **generator_args): Encode, generate, decode.
        translate_text(persian_text, **generator_args): Clean, then translate.
    """

    # Compiled once at class-definition time rather than on every call.
    # Broad emoji blocks: emoticons, symbols & pictographs, transport & map
    # symbols, and regional-indicator (flag) letters.
    _EMOJI_RE = re.compile(
        "["
        "\U0001F600-\U0001F64F"  # emoticons
        "\U0001F300-\U0001F5FF"  # symbols & pictographs
        "\U0001F680-\U0001F6FF"  # transport & map symbols
        "\U0001F1E0-\U0001F1FF"  # flags (iOS)
        "]+",
        flags=re.UNICODE,
    )
    # Individual symbols outside the blocks above: white heart (U+1F90D),
    # heavy black heart (U+2764), love-you gesture (U+1F91F) and victory
    # hand (U+270C). Any trailing variation selector (U+FE0F) or zero-width
    # joiner (U+200D) is removed with them so no invisible residue remains
    # (e.g. the common "✌️" form is U+270C followed by U+FE0F).
    _SYMBOL_RE = re.compile(
        "[\U0001F90D\U00002764\U0001F91F\U0000270C]"
        "[\U0000FE0F\U0000200D]*"
    )
    # '@' is deleted; '#' is rewritten to the literal prefix 'hashtag_'.
    _CHAR_MAP = str.maketrans({"@": "", "#": "hashtag_"})

    def __init__(self, model_size="small"):
        """
        Load the tokenizer and model for the requested checkpoint size.

        Args:
            model_size (str): Size component substituted into the checkpoint id
                ``persiannlp/mt5-{model_size}-parsinlu-opus-translation_fa_en``.
                NOTE(review): presumably one of the sizes published by
                persiannlp (e.g. "small", "base", "large") — confirm.
        """
        self.model_size = model_size
        self.model_name = f"persiannlp/mt5-{self.model_size}-parsinlu-opus-translation_fa_en"
        # from_pretrained resolves the checkpoint id (downloading it if it is
        # not cached locally).
        self.tokenizer = MT5Tokenizer.from_pretrained(self.model_name)
        self.model = MT5ForConditionalGeneration.from_pretrained(self.model_name)

    def clean_persian_text(self, text):
        """
        Remove emojis and selected symbols, and normalise '@' and '#'.

        Args:
            text (str): The input Persian text.

        Returns:
            str: The cleaned text. Emojis in the blocks matched by
            ``_EMOJI_RE`` and the symbols in ``_SYMBOL_RE`` are removed,
            every '@' is deleted and every '#' becomes 'hashtag_'.
        """
        text = self._EMOJI_RE.sub("", text)
        text = self._SYMBOL_RE.sub("", text)
        # One pass over the string instead of a chain of str.replace calls.
        return text.translate(self._CHAR_MAP)

    def run_model(self, input_string, **generator_args):
        """
        Encode ``input_string``, run ``model.generate`` and decode the output.

        Args:
            input_string (str): Text to feed the model.
            **generator_args: Keyword arguments forwarded unchanged to
                ``self.model.generate`` (e.g. ``max_new_tokens``, ``num_beams``).

        Returns:
            list[str]: The decoded sequences from ``tokenizer.batch_decode``,
            with special tokens skipped. With default generation settings this
            is a one-element list; index ``[0]`` for the translation string.
        """
        # A single string is encoded as a PyTorch tensor batch of one sequence.
        input_ids = self.tokenizer.encode(input_string, return_tensors="pt")
        # NOTE(review): without max_length / max_new_tokens, generate() uses the
        # model's generation-config defaults, which may cut long translations
        # short. Pass a limit through generator_args if output looks truncated.
        output_ids = self.model.generate(input_ids, **generator_args)
        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    def translate_text(self, persian_text, **generator_args):
        """
        Clean ``persian_text`` and translate it to English.

        Args:
            persian_text (str): The Persian text to translate.
            **generator_args: Optional keyword arguments forwarded through
                ``run_model`` to ``model.generate``. When omitted, the
                behaviour matches a call without extra arguments.

        Returns:
            list[str]: The decoded translation(s) as returned by ``run_model``
            (a one-element list with default generation settings).
        """
        cleaned = self.clean_persian_text(persian_text)
        return self.run_model(input_string=cleaned, **generator_args)