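# Gradio demo for Norwegian-English translation with the NorT5-based
# "nort5_en-no_base" seq2seq model. The input is translated paragraph by
# paragraph and streamed to the UI as it is generated.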
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
from transformers.generation import LogitsProcessor
from threading import Thread
import gradio as gr
print(f"Starting to load the model to memory")
tokenizer = AutoTokenizer.from_pretrained("nort5_en-no_base")
cls_index = tokenizer.convert_tokens_to_ids("[CLS]")
sep_index = tokenizer.convert_tokens_to_ids("[SEP]")
eos_index = tokenizer.convert_tokens_to_ids("[EOS]")
pad_index = tokenizer.convert_tokens_to_ids("[PAD]")
eng_index = tokenizer.convert_tokens_to_ids(">>eng<<")
nob_index = tokenizer.convert_tokens_to_ids(">>nob<<")
nno_index = tokenizer.convert_tokens_to_ids(">>nno<<")
model = AutoModelForSeq2SeqLM.from_pretrained("nort5_en-no_base", trust_remote_code=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"SYSTEM: Running on {device}", flush=True)
model = model.to(device)
model.eval()
print(f"Sucessfully loaded the model to the memory")
LANGUAGES = [
    "🇬🇧 English",
    "🇳🇴 Norwegian (Bokmål)",
    "🇳🇴 Norwegian (Nynorsk)"
]
LANGUAGE_IDS = {
    "🇬🇧 English": eng_index,
    "🇳🇴 Norwegian (Bokmål)": nob_index,
    "🇳🇴 Norwegian (Nynorsk)": nno_index
}
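
# TextIteratorStreamer assumes a batch size of one; this subclass keeps one
# token cache per paragraph and, on every step, re-decodes all caches and
# emits the paragraphs joined by newlines, so the whole batch streams at once.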
class BatchStreamer(TextIteratorStreamer):
    def put(self, value):
        # Lazily create one token cache per paragraph in the batch.
        if len(self.token_cache) == 0:
            self.token_cache = [[] for _ in range(value.size(0))]

        value = value.tolist()

        # Append the new token(s) to each cache and re-decode everything.
        for c, v in zip(self.token_cache, value):
            c += [v] if isinstance(v, int) else v

        paragraphs = [tokenizer.decode(c, **self.decode_kwargs).strip() for c in self.token_cache]
        text = '\n'.join(paragraphs)
        self.on_finalized_text(text)

    def end(self):
        # Flush whatever is left in the caches and signal the end of the stream.
        if len(self.token_cache) > 0:
            paragraphs = [tokenizer.decode(c, **self.decode_kwargs).strip() for c in self.token_cache]
            printable_text = '\n'.join(paragraphs)
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)
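
# A frequency-aware repetition penalty: every token that already occurs in the
# output gets `penalty * (log_softmax(bias) - max)` added to its score, where
# `bias` is the bias of the model's output layer. The term is zero for the
# token with the largest bias and increasingly negative for rarer tokens, so
# common words are allowed to repeat more freely than rare ones.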
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    def __init__(self, penalty: float, model):
        # Normalize the output-layer bias into log-probabilities and shift it
        # so that the penalty is zero for the highest-bias token.
        last_bias = model.classifier.nonlinearity[-1].bias.data
        last_bias = torch.nn.functional.log_softmax(last_bias, dim=-1)
        self.penalty = penalty * (last_bias - last_bias.max())

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Lower the scores of all tokens that already appear in input_ids.
        penalized_score = torch.gather(scores + self.penalty.unsqueeze(0).to(input_ids.device), 1, input_ids)
        scores.scatter_(1, input_ids, penalized_score)
        return scores
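
# Split the input into paragraphs, encode each one as
# [CLS] >>target<< >>source<< ... [SEP], and stream the batched generation
# back to the caller one decoded snapshot at a time.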
def translate(source, source_language, target_language):
    if source_language == target_language:
        yield source.strip()
        return

    # One batch element per paragraph.
    source = [s.strip() for s in source.split('\n')]
    source_subwords = tokenizer(source).input_ids
    source_subwords = [[cls_index, LANGUAGE_IDS[target_language], LANGUAGE_IDS[source_language]] + s + [sep_index] for s in source_subwords]
    source_subwords = [torch.tensor(s) for s in source_subwords]
    source_subwords = torch.nn.utils.rnn.pad_sequence(source_subwords, batch_first=True, padding_value=pad_index)
    source_subwords = source_subwords[:, :512].to(device)

    streamer = BatchStreamer(tokenizer, timeout=60.0, skip_special_tokens=True)

    def generate(model, **kwargs):
        with torch.inference_mode():
            with torch.autocast(enabled=device != "cpu", device_type=device, dtype=torch.bfloat16):
                return model.generate(**kwargs)

    # Greedy decoding; sampling (top_k=64, top_p=0.95, temperature=0.3) and
    # beam search (num_beams=4, early_stopping=True) are left disabled.
    generate_kwargs = dict(
        streamer=streamer,
        input_ids=source_subwords,
        attention_mask=(source_subwords != pad_index).long(),
        max_new_tokens=512 - 1,
        num_beams=1,
        do_sample=False,
        use_cache=True,
        logits_processor=[RepetitionPenaltyLogitsProcessor(1.0, model)]
    )

    # Generate in a background thread so that partial results can be streamed.
    t = Thread(target=generate, args=(model,), kwargs=generate_kwargs)
    t.start()

    for new_text in streamer:
        yield new_text.strip()
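
# A minimal sketch of how `translate` can be driven outside the UI
# (hypothetical console usage; the language strings must match LANGUAGES):
#
#     for snapshot in translate("Hei, verden!", LANGUAGES[1], LANGUAGES[0]):
#         print(snapshot)  # each snapshot is the full translation so far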
def switch_inputs(source, target, source_language, target_language):
    # Swap the two panels (helper for a swap button; not wired up below).
    return target, source, target_language, source_language
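
# Two-panel UI: language dropdowns, an editable source textbox and a read-only
# target textbox that is overwritten with each streamed snapshot.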
with gr.Blocks() as demo:
    # Alternative theme: gr.Blocks(theme='sudeepshouche/minimalist')
    gr.Markdown("# Norwegian-English translation")

    with gr.Row():
        with gr.Column(scale=7, variant="panel"):
            source_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[1], show_label=False
            )
            source = gr.Textbox(
                label="Source text", placeholder="What do you want to translate?", show_label=False, lines=7, max_lines=100, autofocus=True
            )
            submit = gr.Button("Submit", variant="primary")

        with gr.Column(scale=7, variant="panel"):
            target_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[0], show_label=False
            )
            target = gr.Textbox(
                label="Translation", show_label=False, interactive=False, lines=7, max_lines=100
            )

    # Disable all inputs while a translation is running...
    def update_state_after_user():
        return {
            source: gr.update(interactive=False),
            submit: gr.update(interactive=False),
            source_language: gr.update(interactive=False),
            target_language: gr.update(interactive=False)
        }

    # ...and re-enable them once it has finished.
    def update_state_after_return():
        return {
            source: gr.update(interactive=True),
            submit: gr.update(interactive=True),
            source_language: gr.update(interactive=True),
            target_language: gr.update(interactive=True)
        }

    # Pressing Enter in the source box and clicking Submit trigger the same
    # chain: lock the inputs, stream the translation, unlock the inputs.
    submit_event = source.submit(
        fn=update_state_after_user, inputs=None, outputs=[source, submit, source_language, target_language], queue=False
    ).then(
        fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True
    ).then(
        fn=update_state_after_return, inputs=None, outputs=[source, submit, source_language, target_language], queue=False
    )

    submit_click_event = submit.click(
        fn=update_state_after_user, inputs=None, outputs=[source, submit, source_language, target_language], queue=False
    ).then(
        fn=translate, inputs=[source, source_language, target_language], outputs=[target], queue=True
    ).then(
        fn=update_state_after_return, inputs=None, outputs=[source, submit, source_language, target_language], queue=False
    )

demo.queue(max_size=32, concurrency_count=2)
demo.launch()