# Author: davda54 · commit f0d928c ("Update app.py") · 7.04 kB
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
from transformers.generation import LogitsProcessor
from threading import Thread
import gradio as gr
# Checkpoint used for both the tokenizer and the seq2seq model.
MODEL_NAME = "nort5_en-no_base"

print("Starting to load the model to memory", flush=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Ids of the special tokens and language-tag tokens used to build model inputs.
cls_index = tokenizer.convert_tokens_to_ids("[CLS]")
sep_index = tokenizer.convert_tokens_to_ids("[SEP]")
eos_index = tokenizer.convert_tokens_to_ids("[EOS]")
pad_index = tokenizer.convert_tokens_to_ids("[PAD]")
eng_index = tokenizer.convert_tokens_to_ids(">>eng<<")
nob_index = tokenizer.convert_tokens_to_ids(">>nob<<")
nno_index = tokenizer.convert_tokens_to_ids(">>nno<<")

# trust_remote_code=True executes the modelling code shipped with the
# checkpoint, so MODEL_NAME must point at a trusted checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"SYSTEM: Running on {device}", flush=True)
model = model.to(device)
model.eval()  # inference only: disables dropout etc.
print("Successfully loaded the model to the memory", flush=True)
# Dropdown label -> language-tag token id used by translate() to build the
# model prompt. The labels are the exact strings shown in the dropdowns.
LANGUAGE_IDS = {
    "🇬🇧 English": eng_index,
    "🇳🇴 Norwegian (Bokmål)": nob_index,
    "🇳🇴 Norwegian (Nynorsk)": nno_index,
}

# Dropdown choices, derived from LANGUAGE_IDS (insertion order preserved) so the
# labels in the UI and the lookup keys can never drift apart.
LANGUAGES = list(LANGUAGE_IDS)
class BatchStreamer(TextIteratorStreamer):
    """Text streamer that supports a batch of sequences.

    One token list is kept per batch row. After every generation step all rows
    are re-decoded and published (via ``on_finalized_text``) as one string with
    the rows joined by newlines. Each item read from the iterator is therefore
    the complete text generated so far, one line per batch row.
    """

    def _decode_rows(self):
        """Decode every cached row, strip it, and join the rows with newlines."""
        return "\n".join(
            self.tokenizer.decode(row, **self.decode_kwargs).strip()
            for row in self.token_cache
        )

    def put(self, value):
        """Append one step's tokens to each row and publish the full text.

        ``value`` is indexed per batch row along dim 0. After ``tolist()`` each
        row entry is either a single int (1-D input) or a list of ints (2-D
        input). Presumably the 2-D case is the decoder start tokens passed in
        before the first step; TODO confirm against the installed transformers.
        """
        if self.skip_prompt and self.next_tokens_are_prompt:
            # Same contract as the parent class: optionally drop the first put().
            self.next_tokens_are_prompt = False
            return
        if not self.token_cache:
            self.token_cache = [[] for _ in range(value.size(0))]
        for row, new_tokens in zip(self.token_cache, value.tolist()):
            if isinstance(new_tokens, int):
                row.append(new_tokens)
            else:
                row.extend(new_tokens)
        self.on_finalized_text(self._decode_rows())

    def end(self):
        """Publish the final text, reset state for reuse and signal end of stream."""
        if self.token_cache:
            printable_text = self._decode_rows()
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""
        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    """Lower the scores of tokens that already occur in ``input_ids``.

    The per-token penalty vector is derived from the bias of the model's last
    output layer: ``penalty * (log_softmax(bias) - max(log_softmax(bias)))``.
    For ``penalty >= 0`` every entry is <= 0. The token with the largest bias
    gets 0, and tokens with smaller bias get larger penalties. Only tokens
    present in ``input_ids`` are modified.

    Args:
        penalty: Scale factor for the penalty vector; 0 leaves scores unchanged.
        model: Model exposing ``model.classifier.nonlinearity[-1].bias``.
    """

    def __init__(self, penalty: float, model):
        last_bias = model.classifier.nonlinearity[-1].bias.data
        # Explicit dim: calling log_softmax without one is deprecated.
        last_bias = torch.nn.functional.log_softmax(last_bias, dim=-1)
        self.penalty = penalty * (last_bias - last_bias.max())

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Replace ``scores[b, t]`` by ``scores[b, t] + penalty[t]`` for every token t in ``input_ids[b]``."""
        # Match the scores' device *and* dtype: scatter_ requires src and self
        # to share a dtype, so a float32 penalty would break on non-float32 scores.
        penalty = self.penalty.to(device=scores.device, dtype=scores.dtype).unsqueeze(0)
        penalized_score = torch.gather(scores + penalty, 1, input_ids)
        # Duplicate ids in input_ids write the same value, so nothing is penalized twice.
        scores.scatter_(1, input_ids, penalized_score)
        return scores
def _build_input_ids(subword_lists, prefix, suffix, pad_id, max_length=512):
    """Wrap each token list in ``prefix``/``suffix`` and right-pad into one batch.

    Each row's content is truncated *before* the suffix is appended. A line
    longer than ``max_length`` therefore still ends with its suffix (the [SEP]
    token) instead of having it cut off.

    Args:
        subword_lists: One list of token ids per input line.
        prefix: Token ids prepended to every row.
        suffix: Token ids appended to every row.
        pad_id: Id used to right-pad shorter rows.
        max_length: Maximum row length, including prefix and suffix.

    Returns:
        A ``(batch, longest_row)`` int64 tensor.
    """
    budget = max(max_length - len(prefix) - len(suffix), 0)
    rows = [
        torch.tensor(prefix + list(ids[:budget]) + suffix, dtype=torch.long)
        for ids in subword_lists
    ]
    return torch.nn.utils.rnn.pad_sequence(rows, batch_first=True, padding_value=pad_id)


def translate(source, source_language, target_language):
    """Translate ``source`` line by line and stream the partial result.

    Every line of ``source`` becomes one row of a batch. Generation runs in a
    background thread. Whenever the streamer publishes an update, this
    generator yields the whole translation so far (rows joined by newlines,
    stripped). The last text is also returned.

    Args:
        source: Text to translate; may span several lines.
        source_language: ``LANGUAGE_IDS`` key of the input language.
        target_language: ``LANGUAGE_IDS`` key of the output language.

    Yields:
        The stripped translation produced so far.
    """
    if source_language == target_language or not source.strip():
        # Nothing to translate: echo the stripped input instead of running the model.
        yield source.strip()
        return source.strip()

    lines = [line.strip() for line in source.split('\n')]
    subwords = tokenizer(lines).input_ids
    # Row layout: [CLS] >>target<< >>source<< tokens... [SEP]
    input_ids = _build_input_ids(
        subwords,
        prefix=[cls_index, LANGUAGE_IDS[target_language], LANGUAGE_IDS[source_language]],
        suffix=[sep_index],
        pad_id=pad_index,
        max_length=512,
    ).to(device)

    streamer = BatchStreamer(tokenizer, timeout=60.0, skip_special_tokens=True)

    def generate(model, **kwargs):
        # Runs in the worker thread; results reach this generator via the streamer.
        with torch.inference_mode():
            with torch.autocast(enabled=device != "cpu", device_type=device, dtype=torch.bfloat16):
                return model.generate(**kwargs)

    generate_kwargs = dict(
        streamer=streamer,
        input_ids=input_ids,
        attention_mask=(input_ids != pad_index).long(),
        max_new_tokens=512 - 1,
        num_beams=1,
        do_sample=False,  # greedy decoding
        logits_processor=[RepetitionPenaltyLogitsProcessor(1.0, model)],
        use_cache=True,
    )
    worker = Thread(target=generate, args=(model,), kwargs=generate_kwargs)
    worker.start()

    text = ""
    for text in streamer:
        yield text.strip()
    worker.join()  # the streamer has ended, so generate() is about to return
    return text.strip()
def switch_inputs(source, target, source_language, target_language):
    """Swap the two texts and the two languages.

    Returns ``(target, source, target_language, source_language)``.
    """
    swapped_texts = (target, source)
    swapped_languages = (target_language, source_language)
    return (*swapped_texts, *swapped_languages)
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
    gr.Markdown("# Norwegian-English translation")

    with gr.Row():
        with gr.Column(scale=7, variant="panel"):
            source_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[1], show_label=False
            )
            source = gr.Textbox(
                label="Source text", placeholder="What do you want to translate?", show_label=False, lines=7, max_lines=100, autofocus=True
            )
            submit = gr.Button("Submit", variant="primary")

        with gr.Column(scale=7, variant="panel"):
            target_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[0], show_label=False
            )
            target = gr.Textbox(
                label="Translation", show_label=False, interactive=False, lines=7, max_lines=100
            )

    # Inputs that are locked while a translation is running.
    locked_inputs = [source, submit, source_language, target_language]

    def _set_inputs_interactive(interactive):
        """Build a Gradio update dict setting ``interactive`` on every locked input."""
        return {component: gr.update(interactive=interactive) for component in locked_inputs}

    def update_state_after_user():
        """Lock the inputs when a translation starts."""
        return _set_inputs_interactive(False)

    def update_state_after_return():
        """Unlock the inputs when the translation has finished."""
        return _set_inputs_interactive(True)

    def _wire_translation(trigger):
        """Attach the lock -> translate (queued) -> unlock chain to ``trigger``."""
        return trigger(
            fn=update_state_after_user, inputs=None, outputs=locked_inputs, queue=False
        ).then(
            fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True
        ).then(
            fn=update_state_after_return, inputs=None, outputs=locked_inputs, queue=False
        )

    # Submitting the textbox and clicking the button behave identically.
    submit_event = _wire_translation(source.submit)
    submit_click_event = _wire_translation(submit.click)

demo.queue(max_size=32, concurrency_count=2)
demo.launch()