Spaces:
Runtime error
Runtime error
aus10powell
committed on
Commit
•
33d6c4f
1
Parent(s):
b20b18b
Upload translation.py
Browse files- scripts/translation.py +104 -0
scripts/translation.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import re
|
4 |
+
|
5 |
+
|
6 |
+
class PersianTextProcessor:
    """
    Clean Persian text and translate it to English with an MT5 model.

    Attributes:
        model_size (str): The size tag interpolated into the model name.
        model_name (str): The model identifier passed to ``from_pretrained``.
        tokenizer (MT5Tokenizer): The tokenizer loaded for ``model_name``.
        model (MT5ForConditionalGeneration): The model loaded for ``model_name``.

    Methods:
        clean_persian_text(text): Strips emojis and rewrites "@" / "#".
        run_model(input_string, **generator_args): Encodes, generates, decodes.
        translate_text(persian_text): Cleans the text, then runs the model.
    """

    # Compiled once at class creation instead of on every call.
    _EMOJI_PATTERN = re.compile(
        "["
        "\U0001F600-\U0001F64F"  # emoticons
        "\U0001F300-\U0001F5FF"  # symbols & pictographs
        "\U0001F680-\U0001F6FF"  # transport & map symbols
        "\U0001F1E0-\U0001F1FF"  # flags (iOS)
        "]+",
        flags=re.UNICODE,
    )

    # Individual symbols outside the ranges above (white heart, heavy black
    # heart, love-you gesture, victory hand), together with any trailing
    # variation selector (U+FE0F) or zero-width joiner (U+200D) so that no
    # invisible code points are left behind once the symbol is removed.
    _SYMBOL_PATTERN = re.compile(
        "[\U0001F90D\U00002764\U0001F91F\U0000270C][\U0000FE0F\U0000200D]*"
    )

    def __init__(self, model_size="small"):
        """
        Load the tokenizer and model for the requested model size.

        Args:
            model_size (str): The size tag interpolated into the model name
                (defaults to "small").
        """
        self.model_size = model_size
        self.model_name = f"persiannlp/mt5-{self.model_size}-parsinlu-opus-translation_fa_en"
        self.tokenizer = MT5Tokenizer.from_pretrained(self.model_name)
        self.model = MT5ForConditionalGeneration.from_pretrained(self.model_name)

    def clean_persian_text(self, text: str) -> str:
        """
        Remove emojis from the text and rewrite "@" and "#".

        Emojis in the ranges covered by ``_EMOJI_PATTERN`` and the symbols in
        ``_SYMBOL_PATTERN`` (including their trailing variation selectors or
        zero-width joiners) are deleted, "@" is deleted, and "#" is replaced
        with "hashtag_".

        Args:
            text (str): The input Persian text.

        Returns:
            str: The cleaned text.
        """
        text = self._EMOJI_PATTERN.sub("", text)
        text = self._SYMBOL_PATTERN.sub("", text)
        text = text.replace("@", "")
        return text.replace("#", "hashtag_")

    def run_model(self, input_string: str, **generator_args) -> list:
        """
        Run the MT5 model on the given input string.

        Args:
            input_string (str): The input string.
            **generator_args: Extra keyword arguments forwarded unchanged to
                ``self.model.generate``.

        Returns:
            list: The result of ``self.tokenizer.batch_decode`` on the
            generated sequences, i.e. decoded strings rather than a single
            string.
        """
        # Encode the input string as a tensor of token ids.
        input_ids = self.tokenizer.encode(input_string, return_tensors="pt")

        # Generate the output token sequences.
        generated = self.model.generate(input_ids, **generator_args)

        # Decode every generated sequence, dropping special tokens.
        return self.tokenizer.batch_decode(generated, skip_special_tokens=True)

    def translate_text(self, persian_text: str) -> list:
        """
        Translate the given Persian text to English.

        Args:
            persian_text (str): The Persian text to translate.

        Returns:
            list: The value returned by ``run_model`` for the cleaned text
            (the decoded model output), not a bare string.
        """
        text_cleaned = self.clean_persian_text(persian_text)
        return self.run_model(input_string=text_cleaned)
|