import gradio as gr
import torch
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
# Hugging Face Hub id of the fine-tuned NLLB checkpoint (tokenizer + model).
MODEL_URL = 'slone/nllb-rus-tyv-v2-extvoc'


# Maps the trilingual dropdown labels shown in the UI to NLLB language codes.
lang_to_code = {
    'Орус | Русский | Russian': 'rus_Cyrl',
    'Тыва | Тувинский | Tyvan': 'tyv_Cyrl',
}


def fix_tokenizer(tokenizer, new_lang='tyv_Cyrl'):
    """Register `new_lang` as a language code on an NLLB tokenizer.

    The extra language token is not persisted in the saved vocabulary, so this
    must be re-run each time the tokenizer is (re)loaded. Mutates `tokenizer`
    in place; returns None.
    """
    # If the new language already sits in added_tokens_encoder, discount it so
    # we compute the vocab length as it was before the token was appended.
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    tokenizer.lang_code_to_id[new_lang] = old_len-1
    tokenizer.id_to_lang_code[old_len-1] = new_lang
    # always move "mask" to the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset

    # Re-sync the fairseq id maps (must happen AFTER the mask reposition above,
    # otherwise the mask id computed from the old table would be stale).
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # clear the added token encoder; otherwise a new token may end up there by mistake
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}


def translate(
        text,
        model,
        tokenizer,
        src_lang='rus_Cyrl',
        tgt_lang='tyv_Cyrl',
        max_length='auto',
        num_beams=4,
        no_repeat_ngram_size=4,
        n_out=None,
        **kwargs
):
    """Translate `text` from `src_lang` to `tgt_lang` with an NLLB seq2seq model.

    Inputs longer than 512 tokens are truncated. When `max_length` is 'auto',
    the generation budget scales with the input length. Extra keyword arguments
    are forwarded to `model.generate`. Returns a single string when `text` is a
    string and `n_out` is None; otherwise a list of decoded hypotheses.
    """
    tokenizer.src_lang = src_lang
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    if max_length == 'auto':
        # Heuristic budget: a flat allowance of 32 tokens plus twice the input length.
        max_length = int(32 + 2.0 * inputs.input_ids.shape[1])
    model.eval()
    output_ids = model.generate(
        **inputs.to(model.device),
        forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang],
        max_length=max_length,
        num_beams=num_beams,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_return_sequences=n_out or 1,
        **kwargs
    )
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    wants_single = isinstance(text, str) and n_out is None
    return decoded[0] if wants_single else decoded


def translate_wrapper(text, src, trg, random=False):
    """Gradio callback: map the UI language labels to NLLB codes and translate.

    When `random` is checked, sampling replaces beam search (beam size 1).
    Relies on the module-level `model` and `tokenizer` set up at launch.
    """
    # NOTE: a guard rejecting src == trg was deliberately left disabled here.
    return translate(
        text=text,
        model=model,
        tokenizer=tokenizer,
        src_lang=lang_to_code.get(src),
        tgt_lang=lang_to_code.get(trg),
        do_sample=random,
        num_beams=1 if random else 4,
    )


# Markdown blurb rendered below the Gradio interface (runtime UI content).
article = """
This is a NLLB-200-600M model fine-tuned for translation between Russian and Tyvan (Tuvan) languages,
using the data from https://tyvan.ru/.

**More details will be published soon!**

__Please translate one sentence at a time; the model is not working adequately with multiple sentences!__
"""


# Gradio UI: text box + source/target language dropdowns + sampling checkbox,
# wired to `translate_wrapper`. Defaults: Russian -> Tyvan.
interface = gr.Interface(
    translate_wrapper,
    [
        gr.Textbox(label="Text", lines=2, placeholder='text to translate '),
        gr.Dropdown(list(lang_to_code.keys()), type="value", label='source language', value=list(lang_to_code.keys())[0]),
        gr.Dropdown(list(lang_to_code.keys()), type="value", label='target language', value=list(lang_to_code.keys())[1]),
        gr.Checkbox(label="random", value=False),
    ],
    "text",
    # fixed typo: 'translaton' -> 'translation'
    title='Tyvan-Russian translation',
    article=article,
)


if __name__ == '__main__':
    # These module-level globals are read by translate_wrapper at request time,
    # so the app only works when run as a script (not when imported or served
    # by a runner that skips this guard) — NOTE(review): confirm intended.
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_URL)
    if torch.cuda.is_available():
        model.cuda()
    # NOTE(review): force_download=True re-downloads the tokenizer files on
    # every launch — presumably to bypass a stale cache; verify still needed.
    tokenizer = NllbTokenizer.from_pretrained(MODEL_URL, force_download=True)
    # Re-register the Tyvan language token (lost on every tokenizer reload).
    fix_tokenizer(tokenizer)

    interface.launch()