davda54's picture
Update app.py
398f6f3
raw
history blame
5.82 kB
from threading import Thread

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextIteratorStreamer
# from modeling_nort5 import NorT5ForConditionalGeneration
# Load the tokenizer and the translation model once at import time; the
# request handlers defined below reuse these module-level objects.
print("Starting to load the model to memory", flush=True)

tokenizer = AutoTokenizer.from_pretrained("nort5_en-no_base")

# Special-token ids. `translate` builds its model input as
# [CLS] <target-language tag> <source-language tag> <source subwords> [SEP].
cls_index = tokenizer.convert_tokens_to_ids("[CLS]")
sep_index = tokenizer.convert_tokens_to_ids("[SEP]")
eos_index = tokenizer.convert_tokens_to_ids("[EOS]")
eng_index = tokenizer.convert_tokens_to_ids(">>ENG<<")
nob_index = tokenizer.convert_tokens_to_ids(">>NOB<<")
nno_index = tokenizer.convert_tokens_to_ids(">>NNO<<")

# trust_remote_code=True lets transformers run model code shipped with the
# checkpoint (presumably the custom NorT5 modeling module — confirm).
model = AutoModelForSeq2SeqLM.from_pretrained("nort5_en-no_base", trust_remote_code=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"SYSTEM: Running on {device}", flush=True)
model = model.to(device)
model.eval()  # inference only: disables dropout
print("Successfully loaded the model to the memory", flush=True)
# Generation settings.
# NOTE(review): none of these names are referenced elsewhere in this file
# (`translate` decodes greedily with do_sample=False); they look like leftovers
# from a chat demo — confirm before wiring them in.

# System prompt (Norwegian): "You are NorT5, a language model made at the
# University of Oslo. You are a helpful and harmless assistant ..."
INITIAL_PROMPT: str = "Du er NorT5, en språkmodell laget ved Universitetet i Oslo. Du er en hjelpsom og ufarlig assistent som er glade for å hjelpe brukeren med enhver forespørsel."

# Sampling knobs.
SAMPLE: bool = True
TEMPERATURE: float = 0.7
TOP_K: int = 64
TOP_P: float = 0.95

# Search / repetition knobs.
BEAMS: int = 1
PENALTY: float = 1.2
# Dropdown label -> language-tag token id that `translate` places at the start
# of the model input. (The original literal had a `,` instead of `:` for the
# Nynorsk entry, which is a SyntaxError.)
LANGUAGE_IDS = {
    "🇬🇧 English": eng_index,
    "🇳🇴 Norwegian (Bokmål)": nob_index,
    "🇳🇴 Norwegian (Nynorsk)": nno_index,
}

# Dropdown choices, derived from LANGUAGE_IDS so the two cannot drift apart.
# Dicts keep insertion order, so LANGUAGES[0] is English and LANGUAGES[1] Bokmål.
LANGUAGES = list(LANGUAGE_IDS)
def set_default_target():
    """Return the Markdown placeholder shown in the output box while translating."""
    placeholder = "*Translating...*"
    return placeholder
def translate(source, source_language, target_language):
    """Translate ``source`` from ``source_language`` into ``target_language``.

    Args:
        source: Text to translate.
        source_language: A key of ``LANGUAGE_IDS`` naming the input language.
        target_language: A key of ``LANGUAGE_IDS`` naming the output language.

    Returns:
        The translation as a single string, which is what the output textbox
        expects. If both languages are the same, ``source`` is returned
        unchanged. Blank input returns ``""`` without calling the model.
    """
    if source_language == target_language:
        return source
    if not source or not source.strip():
        return ""

    max_length = 512  # total input budget, including the special tokens

    # Model input layout: [CLS] <target tag> <source tag> <source subwords> [SEP].
    prefix = [cls_index, LANGUAGE_IDS[target_language], LANGUAGE_IDS[source_language]]
    suffix = [sep_index]

    # NOTE(review): the tokenizer is called with its default special-token
    # handling; confirm it does not already insert [CLS]/[SEP] itself.
    source_subwords = tokenizer(source).input_ids

    # Truncate only the text, so the language tags and the closing [SEP]
    # always survive, even for very long inputs.
    budget = max_length - len(prefix) - len(suffix)
    input_ids = prefix + source_subwords[:budget] + suffix

    # Batch of one, placed on the same device the model was moved to.
    input_tensor = torch.tensor([input_ids], device=device)

    output_ids = model.generate(
        input_ids=input_tensor,
        max_new_tokens=max_length - 1,
        do_sample=False,
    )
    # generate() returns a batch; decode the single sequence into a string.
    return tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)
def switch_inputs(source, target, source_language, target_language):
    """Swap the two texts and the two languages (for the disabled swap button)."""
    swapped_texts = (target, source)
    swapped_languages = (target_language, source_language)
    return swapped_texts + swapped_languages
import gradio as gr
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
    gr.Markdown("# Norwegian-English translation")

    with gr.Row():
        # Left panel: source language, text to translate, submit button.
        with gr.Column(scale=7, variant="panel"):
            source_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[0], show_label=False
            )
            source = gr.Textbox(
                label="Source text", placeholder="What do you want to translate?", show_label=False, lines=7, max_lines=100, autofocus=True
            )
            submit = gr.Button("Submit", variant="primary")

        # TODO: a middle column with a swap button wired to `switch_inputs` is
        # currently disabled; its event chain would mirror `_wire_translation`.

        # Right panel: target language and the read-only translation output.
        with gr.Column(scale=7, variant="panel"):
            target_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[1], show_label=False
            )
            target = gr.Textbox(
                label="Translation", show_label=False, interactive=False, lines=7, max_lines=100
            )

    # Controls locked for the duration of a translation request.
    _controls = [source, submit, source_language, target_language]

    def update_state_after_user():
        """Disable all input controls so no second request starts mid-translation."""
        return {component: gr.update(interactive=False) for component in _controls}

    def update_state_after_return():
        """Re-enable every control that update_state_after_user disabled.

        The language dropdowns must be re-enabled too; otherwise they stay
        locked after the first translation.
        """
        return {component: gr.update(interactive=True) for component in _controls}

    def _wire_translation(trigger):
        """Attach the lock -> placeholder -> translate -> unlock chain to ``trigger``."""
        return trigger(
            fn=update_state_after_user, inputs=None, outputs=_controls, queue=False
        ).then(
            fn=set_default_target, inputs=[], outputs=[target], queue=False
        ).then(
            fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True
        ).then(
            fn=update_state_after_return, inputs=None, outputs=_controls, queue=False
        )

    # Pressing Enter in the source box and clicking Submit run the same chain.
    submit_event = _wire_translation(source.submit)
    submit_click_event = _wire_translation(submit.click)
# Enable request queuing (up to 32 pending requests, concurrency_count=2) and
# start the web server. NOTE(review): `concurrency_count` exists only in
# Gradio 3.x; it was removed in Gradio 4.
demo.queue(concurrency_count=2, max_size=32)
demo.launch()