import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
# from modeling_nort5 import NorT5ForConditionalGeneration
from threading import Thread

print("Starting to load the model to memory")

tokenizer = AutoTokenizer.from_pretrained("nort5_en-no_base")

# Special tokens: sentence delimiters and the language tags used to steer translation.
cls_index = tokenizer.convert_tokens_to_ids("[CLS]")
sep_index = tokenizer.convert_tokens_to_ids("[SEP]")
eos_index = tokenizer.convert_tokens_to_ids("[EOS]")
eng_index = tokenizer.convert_tokens_to_ids(">>ENG<<")
nob_index = tokenizer.convert_tokens_to_ids(">>NOB<<")
nno_index = tokenizer.convert_tokens_to_ids(">>NNO<<")

model = AutoModelForSeq2SeqLM.from_pretrained("nort5_en-no_base", trust_remote_code=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"SYSTEM: Running on {device}", flush=True)

model = model.to(device)
model.eval()

print("Successfully loaded the model to memory")

# Generation settings; translate() below uses greedy decoding, so the sampling
# parameters are kept only for experimentation.
INITIAL_PROMPT = "Du er NorT5, en språkmodell laget ved Universitetet i Oslo. Du er en hjelpsom og ufarlig assistent som er glade for å hjelpe brukeren med enhver forespørsel."
TEMPERATURE = 0.7
SAMPLE = True
BEAMS = 1
PENALTY = 1.2
TOP_K = 64
TOP_P = 0.95

LANGUAGES = [
    "🇬🇧 English",
    "🇳🇴 Norwegian (Bokmål)",
    "🇳🇴 Norwegian (Nynorsk)"
]

LANGUAGE_IDS = {
    "🇬🇧 English": eng_index,
    "🇳🇴 Norwegian (Bokmål)": nob_index,
    "🇳🇴 Norwegian (Nynorsk)": nno_index
}


def set_default_target():
    return "*Translating...*"


def translate(source, source_language, target_language):
    if source_language == target_language:
        return source

    # Prepend the target- and source-language tags to the tokenized input,
    # then truncate to the 512-token encoder limit.
    source_subwords = tokenizer(source).input_ids
    source_subwords = [cls_index, LANGUAGE_IDS[target_language], LANGUAGE_IDS[source_language]] + source_subwords + [sep_index]
    source_subwords = torch.tensor([source_subwords[:512]]).to(device)

    predictions = model.generate(
        input_ids=source_subwords,
        max_new_tokens=512 - 1,
        do_sample=False
    )
    predictions = [tokenizer.decode(p, skip_special_tokens=True) for p in predictions.tolist()]

    return predictions[0]


def switch_inputs(source, target, source_language, target_language):
    return target, source, target_language, source_language


import gradio as gr

with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
    gr.Markdown("# Norwegian-English translation")
    # gr.HTML('')
    # gr.Checkbox(label="I want to publish all my conversations", value=True)
    # chatbot = gr.Chatbot(value=[[None, "Hei, hva kan jeg gjøre for deg? 😊"]])
😊"]]) with gr.Row(): with gr.Column(scale=7, variant="panel"): source_language = gr.Dropdown( LANGUAGES, value=LANGUAGES[0], show_label=False ) source = gr.Textbox( label="Source text", placeholder="What do you want to translate?", show_label=False, lines=7, max_lines=100, autofocus=True ) # .style(container=False) submit = gr.Button("Submit", variant="primary") # .style(full_width=True) # with gr.Column(scale=1, variant=None): # switch = gr.Button("🔄") with gr.Column(scale=7, variant="panel"): target_language = gr.Dropdown( LANGUAGES, value=LANGUAGES[1], show_label=False ) target = gr.Textbox( label="Translation", show_label=False, interactive=False, lines=7, max_lines=100 ) def update_state_after_user(): return { source: gr.update(interactive=False), submit: gr.update(interactive=False), source_language: gr.update(interactive=False), target_language: gr.update(interactive=False) } def update_state_after_return(): return { source: gr.update(interactive=True), submit: gr.update(interactive=True), source_language: gr.update(interactive=False), target_language: gr.update(interactive=False) } submit_event = source.submit( fn=update_state_after_user, inputs=None, outputs=[source, submit, source_language, target_language], queue=False ).then( fn=set_default_target, inputs=[], outputs=[target], queue=False ).then( fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True ).then( fn=update_state_after_return, inputs=None, outputs=[source, submit, source_language, target_language], queue=False ) submit_click_event = submit.click( fn=update_state_after_user, inputs=None, outputs=[source, submit, source_language, target_language], queue=False ).then( fn=set_default_target, inputs=[], outputs=[target], queue=False ).then( fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True ).then( fn=update_state_after_return, inputs=None, outputs=[source, submit, source_language, target_language], queue=False ) # switch_event = switch.click( # fn=switch_inputs, inputs=[source, target, source_language, target_language], outputs=[target, source, target_language, source_language], queue=False # ).then( # fn=update_state_after_user, inputs=None, outputs=[source, submit, source_language, target_language], queue=False # ).then( # fn=set_default_target, inputs=[], outputs=[target], queue=False # ).then( # fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True # ).then( # fn=update_state_after_return, inputs=None, outputs=[source, submit, source_language, target_language], queue=False # ) demo.queue(max_size=32, concurrency_count=2) demo.launch()