# File size: 3,631 bytes (revisions a8edea3, fb9114d).
import os
import rwkv_rs
import numpy as np
import huggingface_hub
import tokenizers
import gradio as gr
# Prefer a checkpoint next to the app; otherwise download it from the Hub.
model_path = "./rnn.safetensors"
if not os.path.exists(model_path):
    model_path = huggingface_hub.hf_hub_download(
        repo_id="mrsteyk/RWKV-LM-safetensors",
        filename="RWKV-4-Pile-7B-Instruct-test1-20230124.rnn.safetensors",
    )
if model_path is None:
    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    raise RuntimeError("Could not locate the RWKV model checkpoint")
model = rwkv_rs.Rwkv(model_path)
tokenizer = tokenizers.Tokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# Visibility updates for the button outputs ``[complete, c_stop]``.
# GT: a generation is running -> hide "Complete", show "Stop".
GT = [gr.Button.update(visible=shown) for shown in (False, True)]
# GF: the generation has finished -> show "Complete", hide "Stop".
GF = [gr.Button.update(visible=shown) for shown in (True, False)]
def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
    """Stream a greedy completion of ``inpt`` from the RWKV model.

    Yields ``(text, error_update)`` tuples. While the prompt is being fed,
    ``text`` is the decoded prompt prefix; afterwards it is the prompt plus
    every token generated so far. ``error_update`` is ``None`` on success,
    or a ``gr.Text.update`` that reveals the error box when something fails.

    Args:
        inpt: Prompt text.
        max_tokens: Maximum number of tokens to generate (cast to ``int``).
        min_tokens: For generation steps with index below this value, token 0
            is forbidden, so generation cannot stop on it.
        alpha_f: Frequency penalty, subtracted from a token's logit once per
            time that token has already been generated.
        alpha_p: Presence penalty, subtracted once from the logit of every
            token that has already been generated.
    """
    try:
        state = rwkv_rs.State(model)
        text = inpt
        tokens = tokenizer.encode(inpt).ids
        if not tokens:
            # tokens[-1] below would raise an opaque IndexError; report it clearly.
            yield ("Error...", gr.Text.update(value="Input is empty", visible=True))
            return

        # Feed every prompt token except the last through the preproc path,
        # streaming the decoded prefix as progress.
        for pos, tok in enumerate(tokens[:-1]):
            model.forward_token_preproc(tok, state)
            yield (tokenizer.decode(tokens[:pos + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)

        max_tokens = int(max_tokens)
        # How often each token id has been generated. Sized from the logits so
        # every index np.argmax can return is valid
        # (assumes logits is a 1-D vector over the vocabulary — TODO confirm).
        counts = np.zeros(np.asarray(logits).shape[-1])
        for step in range(max_tokens):
            scores = np.array(logits, dtype=np.float64)
            if step < min_tokens:
                # -inf rather than a finite sentinel, so penalties applied to
                # other tokens can never leave token 0 as the argmax.
                scores[0] = -np.inf
            # Vectorized frequency + presence penalties (was a per-id Python loop).
            scores -= counts * alpha_f + (counts > 0) * alpha_p
            # Plain Python int for the tokenizer / model bindings.
            token = int(np.argmax(scores))
            counts[token] += 1
            if token == 0:
                # Token 0 ends generation (presumably <|endoftext|> for the
                # GPT-NeoX tokenizer — TODO confirm).
                break
            tokens.append(token)
            text = tokenizer.decode(tokens)
            yield (text, None)
            # No forward pass after the final token: its logits would be unused.
            if step < max_tokens - 1:
                logits = model.forward_token(token, state)
        yield (text, None)
    except Exception as e:
        # UI boundary: log and surface the message rather than killing the stream.
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))
def generator_wrap(l, fn, running=None, finished=None):
    """Wrap generator function ``fn`` so each row it yields also updates buttons.

    Args:
        l: Number of values ``fn`` yields per step; used to build an
            all-``None`` row when ``fn`` yields nothing at all.
        fn: Generator function whose items are sequences of length ``l``.
        running: Values appended to every intermediate row (default ``GT``).
        finished: Values appended to the final row (default ``GF``).

    Returns:
        A generator function ``wrap(*args)``. It forwards ``args`` to ``fn``,
        yields each item plus ``running``, and then yields the last item plus
        ``finished``. If ``fn`` raises, the ``finished`` row is still yielded
        before the exception propagates.
    """
    running = list(GT if running is None else running)
    finished = list(GF if finished is None else finished)

    def wrap(*args):
        last = [None] * l
        try:
            for item in fn(*args):
                last = list(item)
                yield last + running
        except Exception:
            # Restore the UI even when fn fails, then propagate the error.
            yield last + finished
            raise
        # Not in ``finally``: yielding there while the generator is being
        # closed (e.g. a cancelled event) raises RuntimeError.
        yield last + finished

    return wrap
with gr.Blocks() as app:
    gr.Markdown(f"Running on `{model_path}`")
    # Hidden until complete_fn reports an error through its second output.
    error_box = gr.Text(label="Error", visible=False)
    with gr.Tab("Complete"):
        with gr.Row():
            inpt = gr.TextArea(label="Input")
            out = gr.TextArea(label="Output")
        complete = gr.Button("Complete", variant="primary")
        c_stop = gr.Button("Stop", variant="stop", visible=False)
    with gr.Tab("Insert (WIP)"):
        gr.Markdown("WIP, use `<|INSERT|>` to indicate a place to replace")
        with gr.Row():
            inpt_i = gr.TextArea(label="Input")
            out_i = gr.TextArea(label="Output")
        # NOTE(review): no click handler is wired to this button yet.
        insert = gr.Button("Insert")
    with gr.Column():
        max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
        min_tokens = gr.Slider(label="Min Tokens", minimum=0, maximum=4096, step=1)
        alpha_f = gr.Slider(label="Alpha Frequency", minimum=0, maximum=100, step=0.01)
        alpha_p = gr.Slider(label="Alpha Presence", minimum=0, maximum=100, step=0.01)
    # Button outputs toggled by GT / GF, in this order.
    G = [complete, c_stop]
    c = complete.click(
        generator_wrap(2, complete_fn),
        [inpt, max_tokens, min_tokens, alpha_f, alpha_p],
        [out, error_box] + G,
    )
    # Stop cancels the running generation and restores the same
    # "finished" button state that generator_wrap emits.
    c_stop.click(lambda: GF, inputs=None, outputs=G, cancels=[c], queue=False)
# Queueing is enabled before launch; the Complete handler streams its
# output through a generator.
app.queue()
app.launch()