File size: 6,065 Bytes
a8edea3 785a54b a8edea3 fb9114d a8edea3 785a54b a8edea3 785a54b a8edea3 785a54b a8edea3 785a54b dfb402a 785a54b a8edea3 785a54b a8edea3 785a54b 2959d62 a8edea3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os
import rwkv_rs
import numpy as np
import huggingface_hub
import tokenizers
import gradio as gr
# Use a checkpoint found in the working directory when there is one;
# otherwise download the published checkpoint from the Hugging Face Hub.
model_path = (
    "./rnn.safetensors"
    if os.path.exists("./rnn.safetensors")
    else huggingface_hub.hf_hub_download(
        repo_id="mrsteyk/RWKV-LM-safetensors",
        filename="RWKV-4-Pile-7B-Instruct-test1-20230124.rnn.safetensors",
    )
)
assert model_path is not None

model = rwkv_rs.Rwkv(model_path)
tokenizer = tokenizers.Tokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# Pairs of button updates that generator_wrap appends to every streamed
# result. The event wiring maps them onto (action button, stop button):
#   GT - generation in progress: hide the action button, show "Stop".
#   GF - generation finished:    show the action button, hide "Stop".
GT = [gr.Button.update(visible=shown) for shown in (False, True)]
GF = [gr.Button.update(visible=shown) for shown in (True, False)]
def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
    """Stream a greedy (argmax) continuation of ``inpt``.

    This is a generator. Every item is a ``(text, error_update)`` pair:
    ``text`` is the decoded text so far (``None`` for the very first item)
    and ``error_update`` is either ``None`` or a ``gr.Text.update`` that
    hides or shows the error box.

    Args:
        inpt: Prompt text; it must encode to at least one token.
        max_tokens: Upper bound on the number of generated tokens
            (converted with ``int``).
        min_tokens: Number of initial steps during which token id 0 - the
            id that ends generation - is suppressed.
        alpha_f: Penalty subtracted from a token's logit once per previous
            occurrence of that token in the generated output.
        alpha_p: Flat penalty subtracted from the logit of every token that
            has already been generated at least once.

    Any exception is printed and reported as a final
    ``("Error...", <visible error box update>)`` item instead of propagating.
    """
    try:
        state = rwkv_rs.State(model)
        text = inpt
        tokens = tokenizer.encode(inpt).ids
        if not tokens:
            # Without this check the failure shows up as an opaque
            # "list index out of range" from tokens[-1] below.
            raise ValueError("Input must contain at least one token")
        # Generated-token id -> number of times it has been produced.
        # Only generated tokens are counted, never prompt tokens.
        counts = {}
        yield (None, gr.Text.update(visible=False))

        # Feed every prompt token except the last through the model, echoing
        # the part of the prompt consumed so far.
        for pos in range(len(tokens) - 1):
            model.forward_token_preproc(tokens[pos], state)
            yield (tokenizer.decode(tokens[:pos + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)

        max_tokens = int(max_tokens)
        for step in range(max_tokens):
            if step < min_tokens:
                # Keep the stop token from winning the argmax too early.
                logits[0] = -100
            # Tokens that were never generated have a zero penalty, so only
            # the tokens seen so far need adjusting.
            for seen, count in counts.items():
                logits[seen] -= (count * alpha_f) + alpha_p
            token = int(np.argmax(logits))
            counts[token] = counts.get(token, 0) + 1
            if token == 0:
                break
            tokens.append(token)
            text = tokenizer.decode(tokens)
            yield (text, None)
            if step == max_tokens - 1:
                # Last step: skip the forward pass whose logits nobody reads.
                break
            logits = model.forward_token(token, state)
        yield (text, None)
    except Exception as e:
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))
def insert_fn(inpt: str, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert):
    """Stream a greedy fill-in for the single ``<|INSERT|>`` marker in ``inpt``.

    The text before the marker is used as the prompt. Generation runs until
    the most recently generated tokens equal the first ``num_tokens_insert``
    tokens of the text after the marker; at that point the remaining tokens
    of that trailing text are appended and generation stops. If that never
    happens within ``max_tokens`` steps (or token id 0 is produced), the
    trailing text is not appended.

    This is a generator. Every item is a ``(text, error_update)`` pair:
    ``text`` is the decoded text so far (``None`` for the very first item)
    and ``error_update`` is either ``None`` or a ``gr.Text.update`` that
    hides or shows the error box.

    Args:
        inpt: Text containing exactly one ``<|INSERT|>`` marker; the part
            before the marker must encode to at least one token.
        max_tokens: Upper bound on the number of generated tokens
            (converted with ``int``).
        min_tokens: Number of initial steps during which token id 0 - the
            id that ends generation - is suppressed.
        alpha_f: Penalty subtracted from a token's logit once per previous
            occurrence of that token in the generated output.
        alpha_p: Flat penalty subtracted from the logit of every token that
            has already been generated at least once.
        num_tokens_insert: How many leading tokens of the trailing text must
            be reproduced to count as a join (converted with ``int``).

    Any exception is printed and reported as a final
    ``("Error...", <visible error box update>)`` item instead of propagating.
    """
    try:
        if inpt.count("<|INSERT|>") != 1:
            yield ("Error...", gr.Text.update(value="Exactly one replace is allowed!", visible=True))
            return
        state = rwkv_rs.State(model)
        text, end = inpt.split("<|INSERT|>")
        tokens = tokenizer.encode(text).ids
        if not tokens:
            # Without this check the failure shows up as an opaque
            # "list index out of range" from tokens[-1] below.
            raise ValueError("Text before <|INSERT|> must contain at least one token")
        tokens_end = tokenizer.encode(end).ids
        # Used as a slice bound below; a float value would raise TypeError.
        num_tokens_insert = int(num_tokens_insert)
        tokens_i = tokens_end[:num_tokens_insert]
        # Sliding window over the most recently generated tokens, compared
        # against tokens_i after every step.
        ins = [0] * len(tokens_i)
        # Generated-token id -> number of times it has been produced.
        counts = {}
        yield (None, gr.Text.update(visible=False))

        # Feed every prefix token except the last through the model, echoing
        # the part of the prefix consumed so far.
        for pos in range(len(tokens) - 1):
            model.forward_token_preproc(tokens[pos], state)
            yield (tokenizer.decode(tokens[:pos + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)

        max_tokens = int(max_tokens)
        for step in range(max_tokens):
            if step < min_tokens:
                # Keep the stop token from winning the argmax too early.
                logits[0] = -100
            # Tokens that were never generated have a zero penalty, so only
            # the tokens seen so far need adjusting.
            for seen, count in counts.items():
                logits[seen] -= (count * alpha_f) + alpha_p
            token = int(np.argmax(logits))
            counts[token] = counts.get(token, 0) + 1
            if token == 0:
                break
            tokens.append(token)
            ins = ins[1:] + [token]
            joined = ins == tokens_i
            if joined:
                # The generated tail already reproduces the start of the
                # trailing text, so only the rest of it is appended.
                tokens += tokens_end[num_tokens_insert:]
            text = tokenizer.decode(tokens)
            yield (text, None)
            if joined or step == max_tokens - 1:
                # Done: skip the forward pass whose logits nobody reads.
                break
            logits = model.forward_token(token, state)
        yield (text, None)
    except Exception as e:
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))
def generator_wrap(l, fn):
    """Wrap generator function ``fn`` so each item also carries button updates.

    Every item produced by ``fn`` is converted to a list and extended with
    ``GT``. When ``fn`` finishes - normally or by raising - one more item is
    emitted: the last item seen (or ``l`` ``None`` placeholders if ``fn``
    produced nothing) extended with ``GF``.

    Args:
        l: Number of values in each item ``fn`` yields; used only to size
            the placeholder emitted when ``fn`` yields nothing.
        fn: Generator function to wrap; called with the wrapper's positional
            arguments.

    Returns:
        A generator function taking the same positional arguments as ``fn``.
    """
    def wrap(*args):
        last_i = [None] * l
        closed = False
        try:
            for item in fn(*args):
                last_i = list(item)
                yield last_i + GT
        except GeneratorExit:
            # The consumer closed this generator. Yielding again from here
            # would raise "RuntimeError: generator ignored GeneratorExit",
            # so the final GF item is skipped on this path.
            closed = True
            raise
        finally:
            if not closed:
                yield last_i + GF
    return wrap
# UI definition: one tab per mode, with the sampling controls shared by both.
with gr.Blocks() as app:
    gr.Markdown(f"Running on `{model_path}`")
    # Hidden until one of the generation functions reports an error.
    error_box = gr.Text(label="Error", visible=False)

    with gr.Tab("Complete"):
        with gr.Row():
            inpt = gr.TextArea(label="Input")
            out = gr.TextArea(label="Output")
        complete = gr.Button("Complete", variant="primary")
        c_stop = gr.Button("Stop", variant="stop", visible=False)

    with gr.Tab("Insert"):
        gr.Markdown("Use `<|INSERT|>` to indicate a place to replace, if insert fails - end text won't be concatenated")
        with gr.Row():
            inpt_i = gr.TextArea(label="Input")
            out_i = gr.TextArea(label="Output")
        num_tokens_insert = gr.Slider(
            label="Number of tokens to compare for ending (from the beginning of 2nd part)",
            minimum=1,
            maximum=2048,
            value=1024,
            step=1,
        )
        insert = gr.Button("Insert", variant="primary")
        i_stop = gr.Button("Stop", variant="stop", visible=False)

    # Sampling controls passed to both complete_fn and insert_fn.
    with gr.Column():
        max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
        min_tokens = gr.Slider(label="Min Tokens", minimum=0, maximum=4096, step=1)
        alpha_f = gr.Slider(label="Alpha Frequency", minimum=0, maximum=100, step=0.01)
        alpha_p = gr.Slider(label="Alpha Presence", minimum=0, maximum=100, step=0.01)

    # Each action streams (output text, error box, action button, stop button).
    # The matching "Stop" handler cancels the running event and restores the
    # buttons itself.
    c = complete.click(generator_wrap(2, complete_fn), [inpt, max_tokens, min_tokens, alpha_f, alpha_p], [out, error_box, complete, c_stop])
    c_stop.click(lambda: (complete.update(visible=True), c_stop.update(visible=False)), inputs=None, outputs=[complete, c_stop], cancels=[c], queue=False)
    i = insert.click(generator_wrap(2, insert_fn), [inpt_i, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert], [out_i, error_box, insert, i_stop])
    i_stop.click(lambda: (insert.update(visible=True), i_stop.update(visible=False)), inputs=None, outputs=[insert, i_stop], cancels=[i], queue=False)

app.queue(concurrency_count=2)
app.launch()