souvorinkg commited on
Commit
0c7fa36
·
verified ·
1 Parent(s): c21898a

Tried using new translate function

Browse files
Files changed (1) hide show
  1. app.py +70 -15
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import gradio as gr
3
 
4
 
@@ -8,12 +8,65 @@ import gradio as gr
8
  # tokenizer_en_to_kin = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb", token=False, src_lang="eng_Latn", tgt_lang="kin_Latn")
9
  #tokenizer_ses_to_en = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb", token=False, src_lang="ses_Latn", tgt_lang="eng_Latn")
10
  model = AutoModelForSeq2SeqLM.from_pretrained("souvorinkg/eng-ses-nllb", token=False).half()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # Load the tokenizer and model for Kinyarwanda to English
13
  # tokenizer_kin_to_en = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb", token=False, src_lang="kin_Latn", tgt_lang="eng_Latn")
14
 
15
- tokenizer_en_to_ses = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb", token=False, src_lang="eng_Latn", tgt_lang="ses_Latn")
16
- tokenizer_tsn_to_eng = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb", token=False, src_lang="tsn_Latn", tgt_lang="eng_Latn")
17
 
18
 
19
 
@@ -29,15 +82,15 @@ tokenizer_tsn_to_eng = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb",
29
  # translated_tokens = model.generate(**inputs, max_length=30, no_repeat_ngram_size=2)
30
  # return tokenizer_kin_to_en.batch_decode(translated_tokens, skip_special_tokens=True)[0]
31
 
32
- def translate_en_to_ses(SourceText):
33
- inputs = tokenizer_en_to_ses(SourceText, return_tensors="pt")
34
- translated_tokens = model.generate(**inputs, max_length=30)
35
- return tokenizer_tsn_to_eng.batch_decode(translated_tokens, skip_special_tokens=True)[0]
36
 
37
- def translate_ses_to_en(SourceText):
38
- inputs = inputs = tokenizer_tsn_to_eng(SourceText, return_tensors="pt")
39
- translated_tokens = model.generate(**inputs, max_length=30)
40
- return tokenizer_en_to_ses.batch_decode(translated_tokens, skip_special_tokens=True)[0]
41
 
42
  # def translate_en_to_tsn(SourceText):
43
  # inputs = tokenizer_en_to_tsn(SourceText, return_tensors="pt")
@@ -45,21 +98,23 @@ def translate_ses_to_en(SourceText):
45
  # return tokenizer_en_to_tsn.batch_decode(translated_tokens, skip_special_tokens=True)[0]
46
 
47
  # Function to handle dropdown selection and call the appropriate translation function
48
- def translate(SourceText, direction):
49
  # if direction == "English to Kinyarwanda":
50
  # return translate_en_to_kin(SourceText)
51
  # if direction == "Kinyarwanda to English":
52
  # return translate_kin_to_en(SourceText)
53
  if direction == "English to Sesotho":
54
- return translate_en_to_ses(SourceText)
 
55
  if direction == "Sesotho to English":
56
- return translate_ses_to_en(SourceText)
 
57
  # if direction == "English to Tswana":
58
  # return translate == translate_en_to_tsn(SourceText)
59
 
60
  # Create the Gradio interface
61
  iface = gr.Interface(
62
- fn=translate,
63
  inputs=[gr.Textbox(lines=2, label="Input Text"), gr.Dropdown(["English to Sesotho", "Sesotho to English"], label="Translation Direction")],
64
  outputs="text",
65
  title="Bilingual Translator",
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, NllbTokenizer
2
  import gradio as gr
3
 
4
 
 
8
  # tokenizer_en_to_kin = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb", token=False, src_lang="eng_Latn", tgt_lang="kin_Latn")
9
  #tokenizer_ses_to_en = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb", token=False, src_lang="ses_Latn", tgt_lang="eng_Latn")
10
  model = AutoModelForSeq2SeqLM.from_pretrained("souvorinkg/eng-ses-nllb", token=False).half()
11
+ tokenizer = NllbTokenizer.from_pretrained("souvorinkg/eng-ses-nllb")
12
+
13
+ def fix_tokenizer(tokenizer, new_lang='ses_Latn'):
14
+ """
15
+ Add a new language token to the tokenizer vocabulary
16
+ (this should be done each time after its initialization)
17
+ """
18
+ old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
19
+ tokenizer.lang_code_to_id[new_lang] = old_len-1
20
+ tokenizer.id_to_lang_code[old_len-1] = new_lang
21
+ # always move "mask" to the last position
22
+ tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
23
+
24
+ tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
25
+ tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
26
+ if new_lang not in tokenizer._additional_special_tokens:
27
+ tokenizer._additional_special_tokens.append(new_lang)
28
+ # clear the added token encoder; otherwise a new token may end up there by mistake
29
+ tokenizer.added_tokens_encoder = {}
30
+ tokenizer.added_tokens_decoder = {}
31
+
32
+ fix_tokenizer(tokenizer)
33
+ model.resize_token_embeddings(len(tokenizer))
34
+
35
+ def translate(
36
+ text, src_lang, tgt_lang,
37
+ a=32, b=3, max_input_length=1024, num_beams=4, **kwargs
38
+ ):
39
+ """Turn a text or a list of texts into a list of translations"""
40
+ tokenizer.src_lang = src_lang
41
+ tokenizer.tgt_lang = tgt_lang
42
+ inputs = tokenizer(
43
+ text, return_tensors='pt', padding=True, truncation=True,
44
+ max_length=max_input_length
45
+ )
46
+ model.eval() # turn off training mode
47
+ result = model.generate(
48
+ **inputs.to(model.device),
49
+ forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
50
+ max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
51
+ num_beams=num_beams, **kwargs
52
+ )
53
+ return tokenizer.batch_decode(result, skip_special_tokens=True)
54
+
55
+ # fixing the new/moved token embeddings in the model
56
+ added_token_id = tokenizer.convert_tokens_to_ids('ses_Latn')
57
+ similar_lang_id = tokenizer.convert_tokens_to_ids('tsn_Latn')
58
+ embeds = model.model.shared.weight.data
59
+ # moving the embedding for "mask" to its new position
60
+ embeds[added_token_id+1] =embeds[added_token_id]
61
+ # initializing new language token with a token of a similar language
62
+ embeds[added_token_id] = embeds[similar_lang_id]
63
+
64
 
65
  # Load the tokenizer and model for Kinyarwanda to English
66
  # tokenizer_kin_to_en = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb", token=False, src_lang="kin_Latn", tgt_lang="eng_Latn")
67
 
68
+ #tokenizer_en_to_ses = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb", token=False, src_lang="eng_Latn", tgt_lang="ses_Latn")
69
+ #tokenizer_tsn_to_eng = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb", token=False, src_lang="tsn_Latn", tgt_lang="eng_Latn")
70
 
71
 
72
 
 
82
  # translated_tokens = model.generate(**inputs, max_length=30, no_repeat_ngram_size=2)
83
  # return tokenizer_kin_to_en.batch_decode(translated_tokens, skip_special_tokens=True)[0]
84
 
85
+ # def translate_en_to_ses(SourceText):
86
+ # inputs = tokenizer_en_to_ses(SourceText, return_tensors="pt")
87
+ # translated_tokens = model.generate(**inputs, max_length=30)
88
+ # return tokenizer_tsn_to_eng.batch_decode(translated_tokens, skip_special_tokens=True)[0]
89
 
90
+ # def translate_ses_to_en(SourceText):
91
+ # inputs = inputs = tokenizer_tsn_to_eng(SourceText, return_tensors="pt")
92
+ # translated_tokens = model.generate(**inputs, max_length=30)
93
+ # return tokenizer_en_to_ses.batch_decode(translated_tokens, skip_special_tokens=True)[0]
94
 
95
  # def translate_en_to_tsn(SourceText):
96
  # inputs = tokenizer_en_to_tsn(SourceText, return_tensors="pt")
 
98
  # return tokenizer_en_to_tsn.batch_decode(translated_tokens, skip_special_tokens=True)[0]
99
 
100
  # Function to handle dropdown selection and call the appropriate translation function
101
+ def translateIn(SourceText, direction):
102
  # if direction == "English to Kinyarwanda":
103
  # return translate_en_to_kin(SourceText)
104
  # if direction == "Kinyarwanda to English":
105
  # return translate_kin_to_en(SourceText)
106
  if direction == "English to Sesotho":
107
+ text = translate(text=SourceText, src_lang='eng_Latn', tgt_lang='ses_Latn')
108
+ return text
109
  if direction == "Sesotho to English":
110
+ text = translate(text=SourceText, src_lang='tsn_Latn', tgt_lang='eng_Latn')
111
+ return text
112
  # if direction == "English to Tswana":
113
  # return translate == translate_en_to_tsn(SourceText)
114
 
115
  # Create the Gradio interface
116
  iface = gr.Interface(
117
+ fn=translateIn,
118
  inputs=[gr.Textbox(lines=2, label="Input Text"), gr.Dropdown(["English to Sesotho", "Sesotho to English"], label="Translation Direction")],
119
  outputs="text",
120
  title="Bilingual Translator",