import os
import json
import torch
import shutil
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
from transformers.generation import LogitsProcessor
import huggingface_hub
from huggingface_hub import Repository
from threading import Thread
import gradio as gr
print(f"Starting to load the model to memory")
tokenizer = AutoTokenizer.from_pretrained("nort5_en-no_base")
cls_index = tokenizer.convert_tokens_to_ids("[CLS]")
sep_index = tokenizer.convert_tokens_to_ids("[SEP]")
eos_index = tokenizer.convert_tokens_to_ids("[EOS]")
pad_index = tokenizer.convert_tokens_to_ids("[PAD]")
eng_index = tokenizer.convert_tokens_to_ids(">>eng<<")
nob_index = tokenizer.convert_tokens_to_ids(">>nob<<")
nno_index = tokenizer.convert_tokens_to_ids(">>nno<<")
model = AutoModelForSeq2SeqLM.from_pretrained("nort5_en-no_base", trust_remote_code=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"SYSTEM: Running on {device}", flush=True)
model = model.to(device)
model.eval()
print(f"Sucessfully loaded the model to the memory")
LANGUAGES = [
    "🇬🇧 English",
    "🇳🇴 Norwegian (Bokmål)",
    "🇳🇴 Norwegian (Nynorsk)"
]
LANGUAGE_IDS = {
    "🇬🇧 English": eng_index,
    "🇳🇴 Norwegian (Bokmål)": nob_index,
    "🇳🇴 Norwegian (Nynorsk)": nno_index
}
STATS_REPO = "https://huggingface.co/datasets/ltg/usage_statistics"
HF_TOKEN = os.environ.get("HF_TOKEN")
dataset = Repository(
    local_dir="data", clone_from=STATS_REPO, use_auth_token=HF_TOKEN
)

# Log an anonymous timestamp for every query to the usage-statistics dataset.
def add_anonymous_usage_log(path):
    global dataset
    try:
        dataset.git_pull()
        with open(path, "a") as f:
            line = json.dumps(str(datetime.now()), ensure_ascii=False)
            f.write(f"{line}\n")
        dataset.push_to_hub(blocking=False)
    except Exception:
        # If the local clone got into a bad state, re-clone the dataset and retry once.
        shutil.rmtree("data")
        dataset = Repository(
            local_dir="data", clone_from=STATS_REPO, use_auth_token=HF_TOKEN
        )
        with open(path, "a") as f:
            line = json.dumps(str(datetime.now()), ensure_ascii=False)
            f.write(f"{line}\n")
        dataset.push_to_hub(blocking=False)
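
# The usage log is a JSONL file in which every line is a single JSON-encoded
# timestamp string, e.g. "2024-01-01 12:00:00.000000" (illustrative value).
# A minimal sketch of how such a file could be read back (not used by the app):
#
#     with open("data/no-en-translation.jsonl") as f:
#         timestamps = [json.loads(line) for line in f]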


class BatchStreamer(TextIteratorStreamer):
    """Streamer that decodes every sequence in the batch (one per input line)
    and emits them joined by newlines, so multi-line input streams line by line."""

    def put(self, value):
        print(value.shape)  # debug: shape of the incoming token ids

        #if value.size(0) == 1:
        #    return super().put(value)

        if len(self.token_cache) == 0:
            self.token_cache = [[] for _ in range(value.size(0))]

        value = value.tolist()

        # Add the new tokens to the cache and decode the entire batch again.
        for c, v in zip(self.token_cache, value):
            c += [v] if isinstance(v, int) else v

        paragraphs = [tokenizer.decode(c, **self.decode_kwargs).strip() for c in self.token_cache]
        text = '\n'.join(paragraphs)

        self.on_finalized_text(text)

    def end(self):
        # Flush whatever is left in the cache and signal the end of the stream.
        if len(self.token_cache) > 0:
            paragraphs = [tokenizer.decode(c, **self.decode_kwargs).strip() for c in self.token_cache]
            printable_text = '\n'.join(paragraphs)
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)


class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    """Additive repetition penalty: tokens that were already generated are pushed
    down in proportion to their log-prior, read off the bias of the output layer."""

    def __init__(self, penalty: float, model):
        last_bias = model.classifier.nonlinearity[-1].bias.data
        last_bias = torch.nn.functional.log_softmax(last_bias, dim=-1)
        self.penalty = penalty * (last_bias - last_bias.max())

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        penalized_score = torch.gather(scores + self.penalty.unsqueeze(0).to(input_ids.device), 1, input_ids)
        scores.scatter_(1, input_ids, penalized_score)
        return scores
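

# A rough, self-contained sketch of what the processor above does (illustrative only,
# never called by the app): with a toy 4-token vocabulary, a token that has already
# been generated receives an additive penalty equal to its output-bias log-prior minus
# the maximum log-prior, so a-priori frequent tokens are discouraged more strongly.
def _repetition_penalty_sketch():
    bias = torch.tensor([0.0, 1.0, 2.0, -1.0])                  # hypothetical output-layer bias
    prior = torch.nn.functional.log_softmax(bias, dim=-1)
    penalty = 1.0 * (prior - prior.max())                        # non-positive: [-2., -1., 0., -3.]
    scores = torch.zeros(1, 4)                                   # flat scores before the penalty
    already_generated = torch.tensor([[0]])                      # token 0 appeared in the output
    penalized = torch.gather(scores + penalty.unsqueeze(0), 1, already_generated)
    scores.scatter_(1, already_generated, penalized)
    return scores                                                # tensor([[-2., 0., 0., 0.]])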


def translate(source, source_language, target_language):
    if source_language == target_language:
        yield source.strip()
        return source.strip()

    # Encode each input line separately; every sequence starts with [CLS] followed by
    # the target- and source-language tags, and ends with [SEP].
    source = [s.strip() for s in source.split('\n')]
    source_subwords = tokenizer(source).input_ids
    source_subwords = [[cls_index, LANGUAGE_IDS[target_language], LANGUAGE_IDS[source_language]] + s + [sep_index] for s in source_subwords]
    source_subwords = [torch.tensor(s) for s in source_subwords]
    source_subwords = torch.nn.utils.rnn.pad_sequence(source_subwords, batch_first=True, padding_value=pad_index)
    source_subwords = source_subwords[:, :512].to(device)

    streamer = BatchStreamer(tokenizer, timeout=60.0, skip_special_tokens=True)

    def generate(model, **kwargs):
        with torch.inference_mode():
            with torch.autocast(enabled=device != "cpu", device_type=device, dtype=torch.bfloat16):
                return model.generate(**kwargs)

    generate_kwargs = dict(
        streamer=streamer,
        input_ids=source_subwords,
        attention_mask=(source_subwords != pad_index).long(),
        max_new_tokens=512 - 1,
        #top_k=64,
        #top_p=0.95,
        #do_sample=True,
        #temperature=0.3,
        num_beams=1,
        #use_cache=True,
        logits_processor=[RepetitionPenaltyLogitsProcessor(1.0, model)],
        # num_beams=4,
        # early_stopping=True,
        do_sample=False,
        use_cache=True
    )
    t = Thread(target=generate, args=(model,), kwargs=generate_kwargs)
    t.start()

    for new_text in streamer:
        yield new_text.strip()

    add_anonymous_usage_log("data/no-en-translation.jsonl")

    return new_text.strip()
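

# Example of how `translate` streams partial translations (illustrative only; in the
# app the generator is consumed by Gradio itself):
#
#     for partial in translate("Hei på deg!", LANGUAGES[1], LANGUAGES[0]):
#         print(partial)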


def switch_inputs(source, target, source_language, target_language):
    return target, source, target_language, source_language


with gr.Blocks() as demo:
    # with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
    gr.Markdown("# Norwegian-English translation")

    with gr.Row():
        with gr.Column(scale=7, variant="panel"):
            source_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[1], show_label=False
            )
            source = gr.Textbox(
                label="Source text", placeholder="What do you want to translate?", show_label=False, lines=7, max_lines=100, autofocus=True
            )  # .style(container=False)
            submit = gr.Button("Submit", variant="primary")  # .style(full_width=True)

        with gr.Column(scale=7, variant="panel"):
            target_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[0], show_label=False
            )
            target = gr.Textbox(
                label="Translation", show_label=False, interactive=False, lines=7, max_lines=100
            )

    # Lock the inputs while a translation is running.
    def update_state_after_user():
        return {
            source: gr.update(interactive=False),
            submit: gr.update(interactive=False),
            source_language: gr.update(interactive=False),
            target_language: gr.update(interactive=False)
        }

    # Unlock the inputs again once the translation has finished.
    def update_state_after_return():
        return {
            source: gr.update(interactive=True),
            submit: gr.update(interactive=True),
            source_language: gr.update(interactive=True),
            target_language: gr.update(interactive=True)
        }

    submit_event = source.submit(
        fn=update_state_after_user, inputs=None, outputs=[source, submit, source_language, target_language], queue=False
    ).then(
        fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True
    ).then(
        fn=update_state_after_return, inputs=None, outputs=[source, submit, source_language, target_language], queue=False
    )

    submit_click_event = submit.click(
        fn=update_state_after_user, inputs=None, outputs=[source, submit, source_language, target_language], queue=False
    ).then(
        fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True
    ).then(
        fn=update_state_after_return, inputs=None, outputs=[source, submit, source_language, target_language], queue=False
    )
demo.queue(max_size=32, concurrency_count=2)
demo.launch()