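# NOTE(review): the text originally here ("Spaces: Sleeping Sleeping") appears to be a
# Hugging Face Spaces status banner captured together with the file; it is not program code.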
from threading import Thread

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextIteratorStreamer

# from modeling_nort5 import NorT5ForConditionalGeneration
# Load the tokenizer and the translation model once, at import time, so every
# request served by the UI reuses the same in-memory instances.
print("Starting to load the model to memory")

tokenizer = AutoTokenizer.from_pretrained("nort5_en-no_base")

# Vocabulary ids of the special tokens used to frame the encoder input.
cls_index = tokenizer.convert_tokens_to_ids("[CLS]")
sep_index = tokenizer.convert_tokens_to_ids("[SEP]")
eos_index = tokenizer.convert_tokens_to_ids("[EOS]")  # not referenced by the visible code

# Vocabulary ids of the language-tag tokens (English, Bokmål, Nynorsk).
eng_index = tokenizer.convert_tokens_to_ids(">>ENG<<")
nob_index = tokenizer.convert_tokens_to_ids(">>NOB<<")
nno_index = tokenizer.convert_tokens_to_ids(">>NNO<<")

# trust_remote_code=True allows the checkpoint to supply its own model class.
model = AutoModelForSeq2SeqLM.from_pretrained("nort5_en-no_base", trust_remote_code=True)

# Prefer the GPU when one is available; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"SYSTEM: Running on {device}", flush=True)

model = model.to(device)
model.eval()  # inference only: switch off training-time behaviour such as dropout

print("Sucessfully loaded the model to the memory")
# Norwegian-language assistant persona prompt. Not referenced by the visible code.
INITIAL_PROMPT = "Du er NorT5, en språkmodell laget ved Universitetet i Oslo. Du er en hjelpsom og ufarlig assistent som er glade for å hjelpe brukeren med enhver forespørsel."

# Sampling-style generation settings. None of these is referenced by the visible
# code: translate() calls model.generate() with do_sample=False.
TEMPERATURE = 0.7
SAMPLE = True
BEAMS = 1
PENALTY = 1.2
TOP_K = 64
TOP_P = 0.95
# Maps each dropdown label to the vocabulary id of its language-tag token.
LANGUAGE_IDS = {
    "🇬🇧 English": eng_index,
    "🇳🇴 Norwegian (Bokmål)": nob_index,
    "🇳🇴 Norwegian (Nynorsk)": nno_index,
}

# Dropdown choices in display order. Derived from LANGUAGE_IDS (dicts preserve
# insertion order) so the labels and the id table can never disagree.
LANGUAGES = list(LANGUAGE_IDS)
def set_default_target():
    """Return the placeholder text shown in the target box while a translation runs."""
    return "*Translating...*"
def translate(source, source_language, target_language):
    """Translate ``source`` from ``source_language`` into ``target_language``.

    Args:
        source: Text to translate.
        source_language: Key of ``LANGUAGE_IDS`` naming the language of ``source``.
        target_language: Key of ``LANGUAGE_IDS`` naming the desired output language.

    Returns:
        The translated text as a single string. ``source`` is returned unchanged
        when both languages are equal, and an empty string is returned for blank
        input (the model is not called in either case).

    Raises:
        KeyError: If either language is not a key of ``LANGUAGE_IDS``.
    """
    if source_language == target_language:
        return source
    if not source or not source.strip():
        return ""

    max_length = 512

    # Encoder input layout: [CLS] <target tag> <source tag> tokens... [SEP]
    prefix = [cls_index, LANGUAGE_IDS[target_language], LANGUAGE_IDS[source_language]]

    # Truncate only the text tokens so the closing [SEP] is never cut off.
    text_ids = tokenizer(source).input_ids[: max_length - len(prefix) - 1]

    # The input must live on the same device as the model.
    input_ids = torch.tensor([prefix + text_ids + [sep_index]], device=device)

    predictions = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_length - 1,
        do_sample=False,
    )

    # The batch holds exactly one sequence; return its text rather than a list
    # so the output textbox shows the translation itself.
    return tokenizer.decode(predictions[0].tolist(), skip_special_tokens=True)
def switch_inputs(source, target, source_language, target_language):
    """Swap the two texts and the two languages.

    Returns:
        A tuple ``(target, source, target_language, source_language)``.
    """
    return target, source, target_language, source_language
import gradio as gr

# Build the UI: a source panel (language, text, submit button) next to a target
# panel (language, read-only translation).
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
    gr.Markdown("# Norwegian-English translation")
    # gr.HTML('<img src="https://huggingface.co/ltg/norbert3-base/resolve/main/norbert.png" width=6.75%>')
    # gr.Checkbox(label="I want to publish all my conversations", value=True)
    # chatbot = gr.Chatbot(value=[[None, "Hei, hva kan jeg gjøre for deg? 😊"]])

    with gr.Row():
        with gr.Column(scale=7, variant="panel"):
            source_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[0], show_label=False
            )
            source = gr.Textbox(
                label="Source text",
                placeholder="What do you want to translate?",
                show_label=False,
                lines=7,
                max_lines=100,
                autofocus=True,
            )  # .style(container=False)
            submit = gr.Button("Submit", variant="primary")  # .style(full_width=True)

        # with gr.Column(scale=1, variant=None):
        #     switch = gr.Button("🔄")

        with gr.Column(scale=7, variant="panel"):
            target_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[1], show_label=False
            )
            target = gr.Textbox(
                label="Translation",
                show_label=False,
                interactive=False,
                lines=7,
                max_lines=100,
            )

    # Components that are locked while a translation is in flight.
    locked_components = [source, submit, source_language, target_language]

    def _set_inputs_interactive(enabled):
        """Return an update that makes every locked component (non-)interactive."""
        return {
            component: gr.update(interactive=enabled)
            for component in locked_components
        }

    def update_state_after_user():
        """Lock the inputs as soon as the user submits a request."""
        return _set_inputs_interactive(False)

    def update_state_after_return():
        """Unlock all inputs, including both language dropdowns, once the request is done."""
        return _set_inputs_interactive(True)

    def _translate_on(trigger):
        """Attach the lock -> placeholder -> translate -> unlock chain to ``trigger``.

        ``trigger`` is an event-listener method such as ``source.submit`` or
        ``submit.click``; the final event of the chain is returned.
        """
        return trigger(
            fn=update_state_after_user, inputs=None, outputs=locked_components, queue=False
        ).then(
            fn=set_default_target, inputs=[], outputs=[target], queue=False
        ).then(
            fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True
        ).then(
            fn=update_state_after_return, inputs=None, outputs=locked_components, queue=False
        )

    # Pressing Enter in the source box and clicking the button behave identically.
    submit_event = _translate_on(source.submit)
    submit_click_event = _translate_on(submit.click)

    # switch_event = switch.click(
    #     fn=switch_inputs, inputs=[source, target, source_language, target_language], outputs=[target, source, target_language, source_language], queue=False
    # ).then(
    #     fn=update_state_after_user, inputs=None, outputs=[source, submit, source_language, target_language], queue=False
    # ).then(
    #     fn=set_default_target, inputs=[], outputs=[target], queue=False
    # ).then(
    #     fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True
    # ).then(
    #     fn=update_state_after_return, inputs=None, outputs=[source, submit, source_language, target_language], queue=False
    # )
# Enable the request queue with the limits given below, then start serving the app.
demo.queue(max_size=32, concurrency_count=2)
demo.launch()