File size: 5,823 Bytes
f61f538
 
8cabb8a
f61f538
 
 
398f6f3
f61f538
398f6f3
 
 
 
 
 
 
f61f538
398f6f3
f61f538
 
 
 
398f6f3
 
f61f538
398f6f3
f61f538
 
 
 
 
 
 
 
 
 
6ed95ae
 
 
 
 
 
398f6f3
 
 
 
 
 
6ed95ae
 
 
 
f61f538
 
398f6f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f61f538
 
6ed95ae
 
 
 
f61f538
 
6ed95ae
95270f3
f61f538
 
 
 
 
 
 
6ed95ae
f61f538
6ed95ae
f61f538
 
6ed95ae
f61f538
 
 
562a084
 
6ed95ae
 
f61f538
6ed95ae
f61f538
 
95270f3
f61f538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ed95ae
 
 
f61f538
 
 
6ed95ae
f61f538
 
 
6ed95ae
 
 
 
 
 
 
 
 
562a084
 
 
 
 
 
 
 
 
 
 
f61f538
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from threading import Thread

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextIteratorStreamer
# from modeling_nort5 import NorT5ForConditionalGeneration


# --- Model and tokenizer set-up (runs once at import time) -----------------
print("Starting to load the model to memory")

# The tokenizer and the model are both loaded from the "nort5_en-no_base"
# checkpoint identifier.
tokenizer = AutoTokenizer.from_pretrained("nort5_en-no_base")

# Vocabulary ids of the control tokens used to frame a translation request.
cls_index = tokenizer.convert_tokens_to_ids("[CLS]")
sep_index = tokenizer.convert_tokens_to_ids("[SEP]")
eos_index = tokenizer.convert_tokens_to_ids("[EOS]")
# Language tags; translate() places the target and source tags right after [CLS].
eng_index = tokenizer.convert_tokens_to_ids(">>ENG<<")
nob_index = tokenizer.convert_tokens_to_ids(">>NOB<<")
nno_index = tokenizer.convert_tokens_to_ids(">>NNO<<")

# AutoModelForSeq2SeqLM must be imported from `transformers` at the top of the
# file. trust_remote_code=True lets the checkpoint supply its own model class.
model = AutoModelForSeq2SeqLM.from_pretrained("nort5_en-no_base", trust_remote_code=True)

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"SYSTEM: Running on {device}", flush=True)

model = model.to(device)
model.eval()  # inference only

print("Sucessfully loaded the model to the memory")


# Prompt and sampling settings. None of these names is referenced anywhere
# else in the visible part of this file: translate() calls model.generate()
# with do_sample=False and passes none of them.
# NOTE(review): these look like leftovers from a chat-style demo — confirm
# before removing.
INITIAL_PROMPT = "Du er NorT5, en språkmodell laget ved Universitetet i Oslo. Du er en hjelpsom og ufarlig assistent som er glade for å hjelpe brukeren med enhver forespørsel."
TEMPERATURE = 0.7
SAMPLE = True
BEAMS = 1
PENALTY = 1.2
TOP_K = 64
TOP_P = 0.95

# Display names offered in both language dropdowns, in the order shown.
LANGUAGES = [
    "🇬🇧 English",
    "🇳🇴 Norwegian (Bokmål)",
    "🇳🇴 Norwegian (Nynorsk)",
]

# Maps each dropdown display name to the vocabulary id of its language tag.
# The keys must match the entries of LANGUAGES exactly, because translate()
# looks the selected dropdown value up in this dict.
LANGUAGE_IDS = {
    "🇬🇧 English": eng_index,
    "🇳🇴 Norwegian (Bokmål)": nob_index,
    "🇳🇴 Norwegian (Nynorsk)": nno_index,
}


def set_default_target():
    """Return the placeholder text written to the target box before translating."""
    placeholder = "*Translating...*"
    return placeholder


def translate(source, source_language, target_language):
    """Translate ``source`` from ``source_language`` into ``target_language``.

    Args:
        source: The text to translate.
        source_language: A key of ``LANGUAGE_IDS`` naming the input language.
        target_language: A key of ``LANGUAGE_IDS`` naming the output language.

    Returns:
        The translation as a single string. When both languages are equal,
        ``source`` is returned unchanged and the model is not called.

    Raises:
        KeyError: If either language is not a key of ``LANGUAGE_IDS``.
    """
    if source_language == target_language:
        return source

    max_input_length = 512

    # Input layout: [CLS] <target tag> <source tag> <source tokens...> [SEP]
    token_ids = [
        cls_index,
        LANGUAGE_IDS[target_language],
        LANGUAGE_IDS[source_language],
        *tokenizer(source).input_ids,
        sep_index,
    ]
    if len(token_ids) > max_input_length:
        # Plain slicing would cut off the closing [SEP]; keep it as the last
        # token of the truncated sequence instead.
        token_ids = token_ids[:max_input_length - 1] + [sep_index]

    # The batch has to live on the same device the model was moved to,
    # otherwise generate() fails when running on CUDA.
    input_ids = torch.tensor([token_ids], device=device)

    output_ids = model.generate(
        input_ids=input_ids,
        max_new_tokens=512 - 1,
        do_sample=False,
    )

    # A single sequence was submitted, so decode that one row and return a
    # plain string (the target Textbox expects text, not a list).
    return tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)


def switch_inputs(source, target, source_language, target_language):
    """Swap the two texts and the two languages.

    Returns:
        A 4-tuple ``(target, source, target_language, source_language)``.
    """
    swapped_texts = (target, source)
    swapped_languages = (target_language, source_language)
    return (*swapped_texts, *swapped_languages)


import gradio as gr

# Two-column translation UI: source text and language on the left, the
# read-only translation and target language on the right.
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:

    gr.Markdown("# Norwegian-English translation")
    # gr.HTML('<img src="https://huggingface.co/ltg/norbert3-base/resolve/main/norbert.png" width=6.75%>')
    # gr.Checkbox(label="I want to publish all my conversations", value=True)

    # chatbot = gr.Chatbot(value=[[None, "Hei, hva kan jeg gjøre for deg? 😊"]])

    with gr.Row():
        with gr.Column(scale=7, variant="panel"):
            source_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[0], show_label=False
            )
            source = gr.Textbox(
                label="Source text", placeholder="What do you want to translate?", show_label=False, lines=7, max_lines=100, autofocus=True
            )  # .style(container=False)
            submit = gr.Button("Submit", variant="primary")  # .style(full_width=True)

        # with gr.Column(scale=1, variant=None):
        #     switch = gr.Button("🔄")

        with gr.Column(scale=7, variant="panel"):
            target_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[1], show_label=False
            )
            target = gr.Textbox(
                label="Translation", show_label=False, interactive=False, lines=7, max_lines=100
            )

    # Components that are locked while a translation is in progress.
    lockable = [source, submit, source_language, target_language]

    def update_state_after_user():
        """Disable every input component while the request is processed."""
        return {component: gr.update(interactive=False) for component in lockable}

    def update_state_after_return():
        """Re-enable every input component once the translation is shown.

        The language dropdowns are unlocked here as well; otherwise they
        would stay disabled after the first translation.
        """
        return {component: gr.update(interactive=True) for component in lockable}

    def _attach_translation(trigger):
        """Register the lock -> placeholder -> translate -> unlock chain on ``trigger``."""
        return trigger(
            fn=update_state_after_user, inputs=None, outputs=lockable, queue=False
        ).then(
            fn=set_default_target, inputs=[], outputs=[target], queue=False
        ).then(
            fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True
        ).then(
            fn=update_state_after_return, inputs=None, outputs=lockable, queue=False
        )

    # Submitting from the text box and clicking the button run the same chain.
    submit_event = _attach_translation(source.submit)
    submit_click_event = _attach_translation(submit.click)

    # switch_event = switch.click(
    #     fn=switch_inputs, inputs=[source, target, source_language, target_language], outputs=[target, source, target_language, source_language], queue=False
    # ).then(
    #     fn=update_state_after_user, inputs=None, outputs=[source, submit, source_language, target_language], queue=False
    # ).then(
    #     fn=set_default_target, inputs=[], outputs=[target], queue=False
    # ).then(
    #     fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True
    # ).then(
    #     fn=update_state_after_return, inputs=None, outputs=[source, submit, source_language, target_language], queue=False
    # )

# Enable the request queue before launching; the translate step of each event
# chain above is registered with queue=True.
# NOTE(review): `concurrency_count` is reportedly not accepted by Gradio 4.x —
# confirm the Gradio version this app is pinned to.
demo.queue(max_size=32, concurrency_count=2)
# Start the app; this runs at import time because there is no __main__ guard.
demo.launch()