vteam27 commited on
Commit
af923d2
·
1 Parent(s): 71ad94c

"Added UI"

Browse files
Files changed (2) hide show
  1. app.py +81 -9
  2. lang_list.py +255 -0
app.py CHANGED
@@ -1,16 +1,88 @@
1
  import gradio as gr
 
 
 
 
 
 
2
  from transformers import SeamlessM4TForTextToText
3
- from transformers import AutoProcessor, SeamlessM4TModel
4
  model = SeamlessM4TForTextToText.from_pretrained("facebook/hf-seamless-m4t-medium")
5
  processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
6
 
7
- text_inputs = processor(text = "Hello, my dog is cute", src_lang="eng", return_tensors="pt")
8
- output_tokens = model.generate(**text_inputs, tgt_lang="pan")
9
- translated_text_from_text = processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
10
- print(translated_text_from_text)
11
 
12
- def greet(name):
13
- return translated_text_from_text
14
 
15
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
16
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from lang_list import (
3
+ LANGUAGE_NAME_TO_CODE,
4
+ T2TT_TARGET_LANGUAGE_NAMES,
5
+ TEXT_SOURCE_LANGUAGE_NAMES,
6
+ )
7
+ DEFAULT_TARGET_LANGUAGE = "English"
8
  from transformers import SeamlessM4TForTextToText
9
+ from transformers import AutoProcessor
10
  model = SeamlessM4TForTextToText.from_pretrained("facebook/hf-seamless-m4t-medium")
11
  processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
12
 
13
+ # text_inputs = processor(text = "Hello, my dog is cute", src_lang="eng", return_tensors="pt")
14
+ # output_tokens = model.generate(**text_inputs, tgt_lang="pan")
15
+ # translated_text_from_text = processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
16
+ # print(translated_text_from_text)
17
 
 
 
18
 
19
+ def run_t2tt(input_text: str, source_language: str, target_language: str) -> str:
20
+ source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
21
+ target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
22
+ text_inputs = processor(text = input_text, src_lang=source_language_code , return_tensors="pt")
23
+ output = model.generate(**text_inputs, tgt_lang=target_language_code)
24
+ output_tokens = processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
25
+ return str(output)
26
+
27
+
28
+
29
+ with gr.Blocks() as demo_t2tt:
30
+ with gr.Row():
31
+ with gr.Column():
32
+ with gr.Group():
33
+ input_text = gr.Textbox(label="Input text")
34
+ with gr.Row():
35
+ source_language = gr.Dropdown(
36
+ label="Source language",
37
+ choices=TEXT_SOURCE_LANGUAGE_NAMES,
38
+ value="English",
39
+ )
40
+ target_language = gr.Dropdown(
41
+ label="Target language",
42
+ choices=T2TT_TARGET_LANGUAGE_NAMES,
43
+ value=DEFAULT_TARGET_LANGUAGE,
44
+ )
45
+ btn = gr.Button("Translate")
46
+ with gr.Column():
47
+ output_text = gr.Textbox(label="Translated text")
48
+
49
+ gr.Examples(
50
+ examples=[
51
+ [
52
+ "My favorite animal is the elephant.",
53
+ "English",
54
+ "French",
55
+ ],
56
+ [
57
+ "My favorite animal is the elephant.",
58
+ "English",
59
+ "Mandarin Chinese",
60
+ ],
61
+ [
62
+ "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
63
+ "English",
64
+ "Hindi",
65
+ ],
66
+ [
67
+ "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
68
+ "English",
69
+ "Spanish",
70
+ ],
71
+ ],
72
+ inputs=[input_text, source_language, target_language],
73
+ outputs=output_text,
74
+ fn=run_t2tt,
75
+ cache_examples=True,
76
+ api_name=False,
77
+ )
78
+
79
+ gr.on(
80
+ triggers=[input_text.submit, btn.click],
81
+ fn=run_t2tt,
82
+ inputs=[input_text, source_language, target_language],
83
+ outputs=output_text,
84
+ api_name="t2tt",
85
+ )
86
+
87
+ if __name__ == "__main__":
88
+ demo_t2tt.launch()
lang_list.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Language dict
2
+ language_code_to_name = {
3
+ "afr": "Afrikaans",
4
+ "amh": "Amharic",
5
+ "arb": "Modern Standard Arabic",
6
+ "ary": "Moroccan Arabic",
7
+ "arz": "Egyptian Arabic",
8
+ "asm": "Assamese",
9
+ "ast": "Asturian",
10
+ "azj": "North Azerbaijani",
11
+ "bel": "Belarusian",
12
+ "ben": "Bengali",
13
+ "bos": "Bosnian",
14
+ "bul": "Bulgarian",
15
+ "cat": "Catalan",
16
+ "ceb": "Cebuano",
17
+ "ces": "Czech",
18
+ "ckb": "Central Kurdish",
19
+ "cmn": "Mandarin Chinese",
20
+ "cym": "Welsh",
21
+ "dan": "Danish",
22
+ "deu": "German",
23
+ "ell": "Greek",
24
+ "eng": "English",
25
+ "est": "Estonian",
26
+ "eus": "Basque",
27
+ "fin": "Finnish",
28
+ "fra": "French",
29
+ "gaz": "West Central Oromo",
30
+ "gle": "Irish",
31
+ "glg": "Galician",
32
+ "guj": "Gujarati",
33
+ "heb": "Hebrew",
34
+ "hin": "Hindi",
35
+ "hrv": "Croatian",
36
+ "hun": "Hungarian",
37
+ "hye": "Armenian",
38
+ "ibo": "Igbo",
39
+ "ind": "Indonesian",
40
+ "isl": "Icelandic",
41
+ "ita": "Italian",
42
+ "jav": "Javanese",
43
+ "jpn": "Japanese",
44
+ "kam": "Kamba",
45
+ "kan": "Kannada",
46
+ "kat": "Georgian",
47
+ "kaz": "Kazakh",
48
+ "kea": "Kabuverdianu",
49
+ "khk": "Halh Mongolian",
50
+ "khm": "Khmer",
51
+ "kir": "Kyrgyz",
52
+ "kor": "Korean",
53
+ "lao": "Lao",
54
+ "lit": "Lithuanian",
55
+ "ltz": "Luxembourgish",
56
+ "lug": "Ganda",
57
+ "luo": "Luo",
58
+ "lvs": "Standard Latvian",
59
+ "mai": "Maithili",
60
+ "mal": "Malayalam",
61
+ "mar": "Marathi",
62
+ "mkd": "Macedonian",
63
+ "mlt": "Maltese",
64
+ "mni": "Meitei",
65
+ "mya": "Burmese",
66
+ "nld": "Dutch",
67
+ "nno": "Norwegian Nynorsk",
68
+ "nob": "Norwegian Bokm\u00e5l",
69
+ "npi": "Nepali",
70
+ "nya": "Nyanja",
71
+ "oci": "Occitan",
72
+ "ory": "Odia",
73
+ "pan": "Punjabi",
74
+ "pbt": "Southern Pashto",
75
+ "pes": "Western Persian",
76
+ "pol": "Polish",
77
+ "por": "Portuguese",
78
+ "ron": "Romanian",
79
+ "rus": "Russian",
80
+ "slk": "Slovak",
81
+ "slv": "Slovenian",
82
+ "sna": "Shona",
83
+ "snd": "Sindhi",
84
+ "som": "Somali",
85
+ "spa": "Spanish",
86
+ "srp": "Serbian",
87
+ "swe": "Swedish",
88
+ "swh": "Swahili",
89
+ "tam": "Tamil",
90
+ "tel": "Telugu",
91
+ "tgk": "Tajik",
92
+ "tgl": "Tagalog",
93
+ "tha": "Thai",
94
+ "tur": "Turkish",
95
+ "ukr": "Ukrainian",
96
+ "urd": "Urdu",
97
+ "uzn": "Northern Uzbek",
98
+ "vie": "Vietnamese",
99
+ "xho": "Xhosa",
100
+ "yor": "Yoruba",
101
+ "yue": "Cantonese",
102
+ "zlm": "Colloquial Malay",
103
+ "zsm": "Standard Malay",
104
+ "zul": "Zulu",
105
+ }
106
+ LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}
107
+
108
+ # Source langs: S2ST / S2TT / ASR don't need source lang
109
+ # T2TT / T2ST use this
110
+ text_source_language_codes = [
111
+ "afr",
112
+ "amh",
113
+ "arb",
114
+ "ary",
115
+ "arz",
116
+ "asm",
117
+ "azj",
118
+ "bel",
119
+ "ben",
120
+ "bos",
121
+ "bul",
122
+ "cat",
123
+ "ceb",
124
+ "ces",
125
+ "ckb",
126
+ "cmn",
127
+ "cym",
128
+ "dan",
129
+ "deu",
130
+ "ell",
131
+ "eng",
132
+ "est",
133
+ "eus",
134
+ "fin",
135
+ "fra",
136
+ "gaz",
137
+ "gle",
138
+ "glg",
139
+ "guj",
140
+ "heb",
141
+ "hin",
142
+ "hrv",
143
+ "hun",
144
+ "hye",
145
+ "ibo",
146
+ "ind",
147
+ "isl",
148
+ "ita",
149
+ "jav",
150
+ "jpn",
151
+ "kan",
152
+ "kat",
153
+ "kaz",
154
+ "khk",
155
+ "khm",
156
+ "kir",
157
+ "kor",
158
+ "lao",
159
+ "lit",
160
+ "lug",
161
+ "luo",
162
+ "lvs",
163
+ "mai",
164
+ "mal",
165
+ "mar",
166
+ "mkd",
167
+ "mlt",
168
+ "mni",
169
+ "mya",
170
+ "nld",
171
+ "nno",
172
+ "nob",
173
+ "npi",
174
+ "nya",
175
+ "ory",
176
+ "pan",
177
+ "pbt",
178
+ "pes",
179
+ "pol",
180
+ "por",
181
+ "ron",
182
+ "rus",
183
+ "slk",
184
+ "slv",
185
+ "sna",
186
+ "snd",
187
+ "som",
188
+ "spa",
189
+ "srp",
190
+ "swe",
191
+ "swh",
192
+ "tam",
193
+ "tel",
194
+ "tgk",
195
+ "tgl",
196
+ "tha",
197
+ "tur",
198
+ "ukr",
199
+ "urd",
200
+ "uzn",
201
+ "vie",
202
+ "yor",
203
+ "yue",
204
+ "zsm",
205
+ "zul",
206
+ ]
207
+ TEXT_SOURCE_LANGUAGE_NAMES = sorted([language_code_to_name[code] for code in text_source_language_codes])
208
+
209
+ # Target langs:
210
+ # S2ST / T2ST
211
+ s2st_target_language_codes = [
212
+ "eng",
213
+ "arb",
214
+ "ben",
215
+ "cat",
216
+ "ces",
217
+ "cmn",
218
+ "cym",
219
+ "dan",
220
+ "deu",
221
+ "est",
222
+ "fin",
223
+ "fra",
224
+ "hin",
225
+ "ind",
226
+ "ita",
227
+ "jpn",
228
+ "kor",
229
+ "mlt",
230
+ "nld",
231
+ "pes",
232
+ "pol",
233
+ "por",
234
+ "ron",
235
+ "rus",
236
+ "slk",
237
+ "spa",
238
+ "swe",
239
+ "swh",
240
+ "tel",
241
+ "tgl",
242
+ "tha",
243
+ "tur",
244
+ "ukr",
245
+ "urd",
246
+ "uzn",
247
+ "vie",
248
+ ]
249
+ S2ST_TARGET_LANGUAGE_NAMES = sorted([language_code_to_name[code] for code in s2st_target_language_codes])
250
+ T2ST_TARGET_LANGUAGE_NAMES = S2ST_TARGET_LANGUAGE_NAMES
251
+
252
+ # S2TT / T2TT / ASR
253
+ S2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
254
+ T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
255
+ ASR_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES