Spaces:
Sleeping
Sleeping
import torch | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer | |
from transformers.generation import LogitsProcessor | |
from threading import Thread | |
import gradio as gr | |
# Checkpoint (local directory or hub id) shared by the tokenizer and the model.
MODEL_PATH = "nort5_en-no_base"

print("Starting to load the model to memory", flush=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Ids of the special and language-tag tokens used to build encoder inputs.
cls_index = tokenizer.convert_tokens_to_ids("[CLS]")
sep_index = tokenizer.convert_tokens_to_ids("[SEP]")
eos_index = tokenizer.convert_tokens_to_ids("[EOS]")
pad_index = tokenizer.convert_tokens_to_ids("[PAD]")
eng_index = tokenizer.convert_tokens_to_ids(">>eng<<")
nob_index = tokenizer.convert_tokens_to_ids(">>nob<<")
nno_index = tokenizer.convert_tokens_to_ids(">>nno<<")

# The checkpoint ships its own modeling code, hence trust_remote_code=True.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, trust_remote_code=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"SYSTEM: Running on {device}", flush=True)
model = model.to(device)
model.eval()  # inference only: disable dropout etc.
print("Successfully loaded the model to the memory", flush=True)
# UI display name -> id of the language-tag token (">>eng<<", ">>nob<<", ">>nno<<")
# that `translate` places at the start of the encoder input.
LANGUAGE_IDS = {
    "🇬🇧 English": eng_index,
    "🇳🇴 Norwegian (Bokmål)": nob_index,
    "🇳🇴 Norwegian (Nynorsk)": nno_index,
}

# Dropdown choices. Derived from LANGUAGE_IDS so the two tables can never drift
# apart; dicts keep insertion order, so the order above is the order in the UI.
LANGUAGES = list(LANGUAGE_IDS)
class BatchStreamer(TextIteratorStreamer):
    """Streamer that accepts a *batch* of sequences from ``model.generate``.

    Keeps one token list per batch row. Every time new tokens arrive it
    decodes all rows and pushes their stripped texts, joined with newlines, to
    the iterator queue. Each emitted item is therefore the full text generated
    so far, not a delta.
    """

    def _decode_all(self):
        """Decode every cached row and join the stripped results with newlines."""
        return '\n'.join(
            self.tokenizer.decode(tokens, **self.decode_kwargs).strip()
            for tokens in self.token_cache
        )

    def put(self, value):
        """Append the newest tokens of every batch row and emit the full text.

        ``value`` is a tensor whose first dimension is the batch. Each row is
        either a single token id (a 1-D step tensor) or a list of ids (a 2-D
        tensor, e.g. the initial decoder input).
        """
        # Same prompt-skipping contract as the parent TextStreamer.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        if not self.token_cache:
            self.token_cache = [[] for _ in range(value.size(0))]
        for cache, row in zip(self.token_cache, value.tolist()):
            cache.extend(row if isinstance(row, list) else [row])
        self.on_finalized_text(self._decode_all())

    def end(self):
        """Emit the final text, reset per-generation state and close the stream."""
        if self.token_cache:
            printable_text = self._decode_all()
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""
        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    """Lower the scores of tokens that already occur in ``input_ids``.

    Unlike the stock ``transformers`` processor of the same name, the penalty
    is token-specific:

        penalty * (log_prior[t] - max(log_prior))

    Here ``log_prior`` is the log-softmax of the bias of the model's last
    output layer. The value is <= 0, so already-seen tokens are pushed down,
    and low-prior tokens are pushed down more than frequent ones. Each seen
    token is penalised once, however often it occurs.
    """

    def __init__(self, penalty: float, model: torch.nn.Module):
        # NOTE(review): relies on the NorT5 remote-code layout, where
        # model.classifier.nonlinearity[-1] has a vocabulary-sized bias. Confirm
        # if the checkpoint changes.
        last_bias = model.classifier.nonlinearity[-1].bias.detach()
        log_prior = torch.nn.functional.log_softmax(last_bias, dim=-1)
        self.penalty = penalty * (log_prior - log_prior.max())

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Match scores' device *and* dtype: scatter_ requires src and dest dtypes to agree.
        penalty = self.penalty.to(device=scores.device, dtype=scores.dtype)
        # Only the entries for seen tokens are needed, so gather them directly
        # instead of adding the penalty to the full (batch, vocab) score matrix.
        penalized_score = scores.gather(1, input_ids) + penalty[input_ids]
        scores.scatter_(1, input_ids, penalized_score)
        return scores
def _encode_batch(lines, source_language, target_language, max_length=512):
    """Build the padded encoder input for a batch of text lines.

    Every row is ``[CLS] <target tag> <source tag> subwords... [SEP]``, and rows
    are right-padded with ``pad_index``. Subwords are truncated *before* the
    special tokens are added. Every row therefore keeps its closing ``[SEP]``
    and is at most *max_length* ids long.

    Returns an int64 tensor of shape ``(len(lines), longest_row)`` on ``device``.
    """
    prefix = [cls_index, LANGUAGE_IDS[target_language], LANGUAGE_IDS[source_language]]
    budget = max_length - len(prefix) - 1  # room for subwords after the prefix and [SEP]
    rows = [
        torch.tensor(prefix + subwords[:budget] + [sep_index])
        for subwords in tokenizer(lines).input_ids
    ]
    padded = torch.nn.utils.rnn.pad_sequence(rows, batch_first=True, padding_value=pad_index)
    return padded.to(device)


def translate(source, source_language, target_language):
    """Translate *source* line by line, streaming the partial result.

    Each line of *source* is translated as a separate row of a single batch.
    Every yielded value is the whole translation produced so far (one line per
    input line), so the consumer can simply overwrite its output with it.

    Args:
        source: text to translate; may contain several lines.
        source_language: key of ``LANGUAGE_IDS`` for the input text.
        target_language: key of ``LANGUAGE_IDS`` for the output text.

    Yields:
        The current (partial) translation, stripped.

    Raises:
        Re-raises any exception from ``model.generate`` in the worker thread
        after the stream has been closed.
    """
    if source_language == target_language:
        yield source.strip()
        return source.strip()

    lines = [line.strip() for line in source.split('\n')]
    input_ids = _encode_batch(lines, source_language, target_language)

    streamer = BatchStreamer(tokenizer, timeout=60.0, skip_special_tokens=True)
    failures = []

    def generate(model, **kwargs):
        try:
            with torch.inference_mode(), torch.autocast(
                enabled=device != "cpu", device_type=device, dtype=torch.bfloat16
            ):
                return model.generate(**kwargs)
        except Exception as error:
            # Close the stream so the loop below stops immediately. Otherwise it
            # would block until the 60 s streamer timeout and lose the real error.
            failures.append(error)
            streamer.end()

    generate_kwargs = dict(
        streamer=streamer,
        input_ids=input_ids,
        attention_mask=(input_ids != pad_index).long(),
        max_new_tokens=512 - 1,
        num_beams=1,
        do_sample=False,  # greedy decoding
        use_cache=True,
        logits_processor=[RepetitionPenaltyLogitsProcessor(1.0, model)],
    )
    # Generate in a background thread; this generator consumes the streamer.
    worker = Thread(target=generate, args=(model,), kwargs=generate_kwargs)
    worker.start()

    text = ""
    for new_text in streamer:
        text = new_text.strip()
        yield text
    worker.join()
    if failures:
        raise failures[0]
    return text
def switch_inputs(source, target, source_language, target_language):
    """Exchange the two sides: the target text/language become the source ones and vice versa.

    Returns ``(target, source, target_language, source_language)``.
    """
    new_source, new_target = target, source
    new_source_language, new_target_language = target_language, source_language
    return new_source, new_target, new_source_language, new_target_language
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
    gr.Markdown("# Norwegian-English translation")

    with gr.Row():
        # Left panel: source language, input text and the submit button.
        with gr.Column(scale=7, variant="panel"):
            source_language = gr.Dropdown(LANGUAGES, value=LANGUAGES[1], show_label=False)
            source = gr.Textbox(
                label="Source text",
                placeholder="What do you want to translate?",
                show_label=False,
                lines=7,
                max_lines=100,
                autofocus=True,
            )
            submit = gr.Button("Submit", variant="primary")

        # Right panel: target language and the read-only translation.
        with gr.Column(scale=7, variant="panel"):
            target_language = gr.Dropdown(LANGUAGES, value=LANGUAGES[0], show_label=False)
            target = gr.Textbox(
                label="Translation", show_label=False, interactive=False, lines=7, max_lines=100
            )

    # Components that are locked while a translation is running.
    controls = [source, submit, source_language, target_language]

    def update_state_after_user():
        """Disable every control before the translation starts."""
        return {component: gr.update(interactive=False) for component in controls}

    def update_state_after_return():
        """Re-enable every control after the translation event completes."""
        return {component: gr.update(interactive=True) for component in controls}

    def _wire_translation(trigger):
        """Attach the lock -> translate (queued, streamed) -> unlock chain to *trigger*."""
        locked = trigger(fn=update_state_after_user, inputs=None, outputs=controls, queue=False)
        translated = locked.then(
            fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True
        )
        return translated.then(fn=update_state_after_return, inputs=None, outputs=controls, queue=False)

    # Pressing Enter in the textbox and clicking the button behave identically.
    submit_event = _wire_translation(source.submit)
    submit_click_event = _wire_translation(submit.click)

demo.queue(max_size=32, concurrency_count=2)
demo.launch()