# Hugging Face file-viewer header (page residue, not part of the program):
#   davda54's picture
#   Update app.py
#   9c65c2d verified
#   raw
#   history blame
#   8.1 kB
import os
import json
import torch
import shutil
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
from transformers.generation import LogitsProcessor
import huggingface_hub
from huggingface_hub import Repository
from threading import Thread
import gradio as gr
# --- Model and tokenizer setup (runs once at import time) -------------------
print(f"Starting to load the model to memory")

_CHECKPOINT = "nort5_en-no_base"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)

# Ids of the special tokens and language tags used elsewhere in this file,
# looked up one by one in the order listed.
(
    cls_index,
    sep_index,
    eos_index,
    pad_index,
    eng_index,
    nob_index,
    nno_index,
) = (
    tokenizer.convert_tokens_to_ids(token)
    for token in ("[CLS]", "[SEP]", "[EOS]", "[PAD]", ">>eng<<", ">>nob<<", ">>nno<<")
)

model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT, trust_remote_code=True)

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"SYSTEM: Running on {device}", flush=True)

model = model.to(device)
model.eval()  # inference only
print(f"Sucessfully loaded the model to the memory")
# Labels offered in the language dropdowns, in display order.
LANGUAGES = [
    "🇬🇧 English",
    "🇳🇴 Norwegian (Bokmål)",
    "🇳🇴 Norwegian (Nynorsk)",
]

# Dropdown label -> id of the matching language-tag token
# (paired positionally with LANGUAGES: English, Bokmål, Nynorsk).
LANGUAGE_IDS = dict(zip(LANGUAGES, (eng_index, nob_index, nno_index)))
# Dataset repository that receives the usage log, cloned into ./data.
STATS_REPO = "https://huggingface.co/datasets/ltg/usage_statistics"
# Access token read from the environment; None when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
dataset = Repository(local_dir="data", clone_from=STATS_REPO, use_auth_token=HF_TOKEN)
def _append_timestamp_and_push(path):
    """Append the current local time to ``path`` as one JSON line, then push.

    Uses the module-level ``dataset`` repository as it is at call time, so a
    freshly re-cloned repository is picked up automatically.
    """
    line = json.dumps(str(datetime.now()), ensure_ascii=False)
    with open(path, "a") as f:
        f.write(f"{line}\n")
    dataset.push_to_hub(blocking=False)


# log the timestamp of the query
def add_anonymous_usage_log(path):
    """Record one usage event (a timestamp only) in the statistics repository.

    The local clone is pulled, the timestamp is appended to ``path`` and the
    change is pushed without blocking. If any of that fails, the local clone
    in ``data`` is discarded, the repository is cloned again and the append
    and push are retried once.

    Args:
        path: Path of the JSON-lines log file inside the local clone.

    Raises:
        Exception: Whatever the retry raises; a failure of the second attempt
            is not handled here.
    """
    global dataset
    try:
        dataset.git_pull()
        _append_timestamp_and_push(path)
    except Exception as error:
        # `Exception` rather than a bare `except`, so that Ctrl-C / SystemExit
        # are not swallowed and do not trigger deletion of the clone.
        print(
            f"SYSTEM: Usage log update failed ({error!r}); re-cloning {STATS_REPO}",
            flush=True,
        )
        # The clone may be missing or half-written; that must not stop the
        # recovery, so removal errors are ignored.
        shutil.rmtree("data", ignore_errors=True)
        dataset = Repository(
            local_dir="data", clone_from=STATS_REPO, use_auth_token=HF_TOKEN
        )
        _append_timestamp_and_push(path)
class BatchStreamer(TextIteratorStreamer):
    """Text streamer that accepts a whole batch of sequences at once.

    Every call to :meth:`put` appends the new token(s) of each batch row to a
    per-row cache, decodes every row from scratch and emits the rows joined by
    newlines. Each emitted text is therefore the complete output so far, not
    an increment.
    """

    def _decode_cache(self):
        """Decode each cached row, strip it, and join the rows with newlines."""
        return '\n'.join(
            self.tokenizer.decode(tokens, **self.decode_kwargs).strip()
            for tokens in self.token_cache
        )

    def put(self, value):
        """Add new token ids for every batch row and emit the decoded batch.

        Args:
            value: Tensor whose first dimension is the batch. After
                ``tolist()`` each row is either a single int or a list of
                ints; both forms are handled.
        """
        # First call of a generation: create one empty token list per row.
        if not self.token_cache:
            self.token_cache = [[] for _ in range(value.size(0))]

        for cache, new_tokens in zip(self.token_cache, value.tolist()):
            if isinstance(new_tokens, int):
                cache.append(new_tokens)
            else:
                cache.extend(new_tokens)

        self.on_finalized_text(self._decode_cache())

    def end(self):
        """Emit the final text, reset the cache and signal the end of stream."""
        if self.token_cache:
            printable_text = self._decode_cache()
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    """Lower the scores of tokens that already occur in ``input_ids``.

    The per-token penalty is derived from the bias of the last layer in
    ``model.classifier.nonlinearity``: the bias is log-softmax-normalised and
    shifted so that its maximum is zero, then scaled by ``penalty``. Every
    entry is thus non-positive for a positive ``penalty``.
    NOTE(review): presumably this bias has one entry per vocabulary item -
    confirm against the model definition.
    """

    def __init__(self, penalty: float, model):
        """Precompute the penalty vector.

        Args:
            penalty: Scale applied to the normalised bias.
            model: Model exposing ``classifier.nonlinearity[-1].bias``.
        """
        last_bias = model.classifier.nonlinearity[-1].bias.data
        # Explicit `dim`: calling log_softmax without it relies on the
        # deprecated implicit-dimension choice and emits a warning. For a
        # 1-D bias the last dimension is the one that was picked implicitly.
        last_bias = torch.nn.functional.log_softmax(last_bias, dim=-1)
        self.penalty = penalty * (last_bias - last_bias.max())

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Add the penalty to the scores of all token ids found in ``input_ids``.

        ``scores`` is modified in place and also returned.
        """
        penalty = self.penalty.unsqueeze(0).to(input_ids.device)
        # Pick the penalised score of every token id present in each row ...
        penalized_score = torch.gather(scores + penalty, 1, input_ids)
        # ... and write it back at those positions only.
        scores.scatter_(1, input_ids, penalized_score)
        return scores
def translate(source, source_language, target_language):
    """Translate ``source`` and stream the result as it is generated.

    This is a generator. Each input line is translated as one row of a batch,
    and every yielded string is the stripped text received from the streamer
    at that step. When both languages are the same, the stripped input is
    yielded once and nothing is generated.

    Args:
        source: Text to translate; it is split on newlines.
        source_language: Key of ``LANGUAGE_IDS`` for the input language.
        target_language: Key of ``LANGUAGE_IDS`` for the output language.

    Yields:
        The stripped text produced so far.
    """
    # Same language on both sides: echo the input back.
    if source_language == target_language:
        yield source.strip()
        return source.strip()

    # One batch row per input line.
    lines = [line.strip() for line in source.split('\n')]
    token_rows = tokenizer(lines).input_ids

    # Row layout: [CLS] <target tag> <source tag> tokens... [SEP]
    prefix = [cls_index, LANGUAGE_IDS[target_language], LANGUAGE_IDS[source_language]]
    batch = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(prefix + row + [sep_index]) for row in token_rows],
        batch_first=True,
        padding_value=pad_index,
    )
    # Keep at most the first 512 positions of every row.
    batch = batch[:, :512].to(device)

    streamer = BatchStreamer(tokenizer, timeout=60.0, skip_special_tokens=True)

    def run_generation(model, **kwargs):
        # bfloat16 autocast is enabled only when not running on the CPU.
        with torch.inference_mode(), torch.autocast(
            enabled=device != "cpu", device_type=device, dtype=torch.bfloat16
        ):
            return model.generate(**kwargs)

    generation_options = {
        "streamer": streamer,
        "input_ids": batch,
        "attention_mask": (batch != pad_index).long(),
        "max_new_tokens": 512 - 1,
        "num_beams": 1,
        "logits_processor": [RepetitionPenaltyLogitsProcessor(1.0, model)],
        "do_sample": False,
        "use_cache": True,
    }

    # Generation runs in a background thread; this generator consumes the
    # streamer and forwards each piece of text to the caller.
    worker = Thread(target=run_generation, args=(model,), kwargs=generation_options)
    worker.start()

    for new_text in streamer:
        yield new_text.strip()

    add_anonymous_usage_log("data/no-en-translation.jsonl")
    return new_text.strip()
def switch_inputs(source, target, source_language, target_language):
    """Swap the two texts and the two languages.

    Returns:
        A 4-tuple ``(target, source, target_language, source_language)``.
    """
    swapped_texts = (target, source)
    swapped_languages = (target_language, source_language)
    return (*swapped_texts, *swapped_languages)
# --- User interface -----------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Norwegian-English translation")

    with gr.Row():
        # Left panel: source language, input text and the submit button.
        with gr.Column(scale=7, variant="panel"):
            source_language = gr.Dropdown(LANGUAGES, value=LANGUAGES[1], show_label=False)
            source = gr.Textbox(
                label="Source text",
                placeholder="What do you want to translate?",
                show_label=False,
                lines=7,
                max_lines=100,
                autofocus=True,
            )
            submit = gr.Button("Submit", variant="primary")

        # Right panel: target language and the read-only translation.
        with gr.Column(scale=7, variant="panel"):
            target_language = gr.Dropdown(LANGUAGES, value=LANGUAGES[0], show_label=False)
            target = gr.Textbox(
                label="Translation",
                show_label=False,
                interactive=False,
                lines=7,
                max_lines=100,
            )

    def _set_interactive(enabled):
        """Build an update toggling interactivity of all input widgets."""
        widgets = (source, submit, source_language, target_language)
        return {widget: gr.update(interactive=enabled) for widget in widgets}

    def update_state_after_user():
        """Lock the inputs while a translation is running."""
        return _set_interactive(False)

    def update_state_after_return():
        """Unlock the inputs once the translation has finished."""
        return _set_interactive(True)

    def _attach_translation(register):
        """Register the lock -> translate -> unlock chain on one event."""
        return register(
            fn=update_state_after_user,
            inputs=None,
            outputs=[source, submit, source_language, target_language],
            queue=False,
        ).then(
            fn=translate,
            inputs=[source, source_language, target_language],
            outputs=[target],
            queue=True,
        ).then(
            fn=update_state_after_return,
            inputs=None,
            outputs=[source, submit, source_language, target_language],
            queue=False,
        )

    # The same chain is attached to the text box's submit event and to a
    # click on the button.
    submit_event = _attach_translation(source.submit)
    submit_click_event = _attach_translation(submit.click)

demo.queue(max_size=32, concurrency_count=2)
demo.launch()