"""Gradio demo: streaming greedy completion with an RWKV model via ``rwkv_rs``.

Weights are read from ``./rnn.safetensors`` when that file exists and are
otherwise downloaded from the Hugging Face Hub; text is tokenized with the
``EleutherAI/gpt-neox-20b`` tokenizer.
"""
import os
import traceback

import rwkv_rs
import numpy as np
import huggingface_hub
import tokenizers
import gradio as gr

# Prefer a local copy of the weights; fall back to downloading them.
model_path = "./rnn.safetensors"
if not os.path.exists(model_path):
    model_path = huggingface_hub.hf_hub_download(
        repo_id="mrsteyk/RWKV-LM-safetensors",
        filename="RWKV-4-Pile-7B-Instruct-test1-20230124.rnn.safetensors",
    )
assert model_path is not None

model = rwkv_rs.Rwkv(model_path)
tokenizer = tokenizers.Tokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# Token id whose selection ends generation; it is masked out while fewer than
# ``min_tokens`` generation steps have run.
STOP_TOKEN = 0

# Visibility updates for the ``[complete, c_stop]`` button pair:
# GT while a generation is running (hide "Complete", show "Stop"),
# GF once it has ended (show "Complete", hide "Stop").
GT = [
    gr.Button.update(visible=False),
    gr.Button.update(visible=True),
]
GF = [
    gr.Button.update(visible=True),
    gr.Button.update(visible=False),
]


def _penalized_scores(logits, counts, alpha_f, alpha_p, suppress_stop):
    """Return a float64 copy of ``logits`` with repetition penalties applied.

    Each entry is lowered by ``counts * alpha_f`` (frequency penalty) plus
    ``alpha_p`` for every token generated at least once (presence penalty).
    With ``suppress_stop`` the stop token is set to ``-inf`` so it can never
    be the argmax, however far the penalties push the other entries down.
    ``counts`` must have the same length as ``logits``.
    """
    scores = np.array(logits, dtype=np.float64)  # copy: never mutate model output
    scores -= counts * alpha_f + (counts > 0) * alpha_p
    if suppress_stop:
        scores[STOP_TOKEN] = -np.inf
    return scores


def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
    """Stream a greedy completion of ``inpt``.

    Yields ``(text, error_update)`` pairs for the ``[out, error_box]``
    outputs. The first pair clears and hides the error box. While the prompt
    is fed to the model, ``text`` shows the prompt decoded so far. After that,
    each generated token is appended until the stop token is chosen or
    ``max_tokens`` tokens have been generated. On failure ``text`` becomes
    ``"Error..."`` and the error box shows the message.

    Args:
        inpt: Prompt text; must encode to at least one token.
        max_tokens: Upper bound on generated tokens (converted with ``int``).
        min_tokens: Number of initial generation steps during which the stop
            token cannot be chosen.
        alpha_f: Frequency penalty, multiplied by a token's occurrence count.
        alpha_p: Presence penalty, applied once a token has occurred.
    """
    try:
        # Hide an error box left visible by a previous failed run.
        yield ("", gr.Text.update(value="", visible=False))

        tokens = tokenizer.encode(inpt).ids
        if not tokens:
            yield ("Error...", gr.Text.update(value="Input is empty", visible=True))
            return

        state = rwkv_rs.State(model)
        # All but the last prompt token go through forward_token_preproc (its
        # return value is unused); the last one goes through forward_token,
        # whose logits seed generation.
        for i, token in enumerate(tokens[:-1]):
            model.forward_token_preproc(token, state)
            yield (tokenizer.decode(tokens[:i + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        text = inpt
        yield (text, None)

        # Occurrence count per token id, indexed exactly like the logits.
        counts = np.zeros(len(logits), dtype=np.int64)
        max_tokens = int(max_tokens)
        for step in range(max_tokens):
            scores = _penalized_scores(
                logits, counts, alpha_f, alpha_p, suppress_stop=step < min_tokens
            )
            token = int(np.argmax(scores))
            if token == STOP_TOKEN:
                break
            counts[token] += 1
            tokens.append(token)
            text = tokenizer.decode(tokens)
            yield (text, None)
            # The final token is already shown; another forward pass would
            # only produce logits that are never used.
            if step == max_tokens - 1:
                break
            logits = model.forward_token(token, state)
        yield (text, None)
    except Exception as e:
        # UI boundary: keep the full traceback in the logs, show the message.
        traceback.print_exc()
        yield ("Error...", gr.Text.update(value=str(e), visible=True))


def generator_wrap(l, fn):
    """Wrap generator ``fn`` so its updates also toggle the Complete/Stop buttons.

    ``fn`` must yield sequences of ``l`` output values. Each one is forwarded
    with ``GT`` appended. When ``fn`` is exhausted or raises, the last values
    (``l`` ``None``s if nothing was yielded) are sent once more with ``GF``
    appended, and any exception is then re-raised.
    """
    def wrap(*args):
        last = [None] * l
        try:
            for outputs in fn(*args):
                last = list(outputs)
                yield last + GT
        except Exception:
            yield last + GF
            raise
        # Deliberately not in ``finally``: yielding while the generator is
        # being closed (GeneratorExit, e.g. on cancel) raises RuntimeError.
        yield last + GF

    return wrap


with gr.Blocks() as app:
    gr.Markdown(f"Running on `{model_path}`")
    error_box = gr.Text(label="Error", visible=False)

    with gr.Tab("Complete"):
        with gr.Row():
            inpt = gr.TextArea(label="Input")
            out = gr.TextArea(label="Output")
        complete = gr.Button("Complete", variant="primary")
        c_stop = gr.Button("Stop", variant="stop", visible=False)
    # Layout only: no event handler is attached to these components here.
    with gr.Tab("Insert (WIP)"):
        gr.Markdown("WIP, use `<|INSERT|>` to indicate a place to replace")
        with gr.Row():
            inpt_i = gr.TextArea(label="Input")
            out_i = gr.TextArea(label="Output")
        insert = gr.Button("Insert")

    # Generation settings (inputs to the Complete handler).
    with gr.Column():
        max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
        min_tokens = gr.Slider(label="Min Tokens", minimum=0, maximum=4096, step=1)
        alpha_f = gr.Slider(label="Alpha Frequency", minimum=0, maximum=100, step=0.01)
        alpha_p = gr.Slider(label="Alpha Presence", minimum=0, maximum=100, step=0.01)

    G = [complete, c_stop]
    c = complete.click(
        generator_wrap(2, complete_fn),
        [inpt, max_tokens, min_tokens, alpha_f, alpha_p],
        [out, error_box] + G,
    )
    # Cancel a running completion. The buttons are reset here because a
    # cancelled run does not reach generator_wrap's final GF update.
    c_stop.click(
        lambda: (complete.update(visible=True), c_stop.update(visible=False)),
        inputs=None,
        outputs=[complete, c_stop],
        cancels=[c],
        queue=False,
    )

app.queue()
app.launch()