Spaces:
Running
Running
File size: 6,023 Bytes
4a2c956 e63ee0a 7c790c0 61b9ff7 4643bb5 7620069 4a2c956 7c790c0 a00d592 7c790c0 a00d592 7c790c0 bd8572f e63ee0a 2d1a42e e63ee0a 2d1a42e e63ee0a 2d1a42e e63ee0a 2d1a42e e63ee0a 2d1a42e e63ee0a 2d1a42e e63ee0a 2d1a42e e63ee0a 7c790c0 e63ee0a 8105dac 7c790c0 ea2eccb 7c790c0 0886a44 7c790c0 e63ee0a 29c4970 0083156 7c790c0 e8c4349 4b42892 8105dac eff3ce0 e8c4349 8105dac e462240 7c790c0 e63ee0a a00d592 d93e535 e63ee0a 9baa9ae 8105dac a00d592 e63ee0a 8105dac 2d1a42e e63ee0a 2d1a42e 7c790c0 e63ee0a 7c790c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import gradio as gr
import os, gc, copy, torch
from datetime import datetime
from huggingface_hub import hf_hub_download
from pynvml import *
# Initialise NVML and keep a handle to GPU 0; evaluate() uses the handle to
# print VRAM usage.
nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)
# Upper bound on prompt tokens fed to the model: evaluate() keeps only the
# last `ctx_limit` tokens of the encoded prompt.
ctx_limit = 2000
# Checkpoint name; doubles as the download filename stem and the page title.
title = "RWKV-5-World-1.5B-v2-OnlyForTest_70%_trained-20231016-ctx4096"
# NOTE(review): these flags are set before `rwkv.model` is imported below,
# presumably because the package reads them at import time — keep this order.
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
from rwkv.model import RWKV
# Resolve the checkpoint file through the Hugging Face Hub and load it with
# the 'cuda fp16' strategy.
model_path = hf_hub_download(repo_id="BlinkDL/temp", filename=f"{title}.pth")
model = RWKV(model=model_path, strategy='cuda fp16')
from rwkv.utils import PIPELINE, PIPELINE_ARGS
# Helper object used by evaluate() for encode / decode / sample_logits.
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
def _collapse_blank_lines(text):
    """Return *text* stripped, with CRLF turned into LF and empty lines removed."""
    text = text.strip().replace('\r\n', '\n')
    # Dropping the empty segments collapses a newline run of ANY length to a
    # single '\n'. A single `.replace('\n\n', '\n')` pass would leave a blank
    # line behind for runs of three or more newlines.
    return '\n'.join(segment for segment in text.split('\n') if segment)


def generate_prompt(instruction, input=""):
    """Build the text prompt sent to the model.

    Both arguments are stripped, CRLF line endings are normalised to LF and
    blank lines are removed before being placed into the template.

    Args:
        instruction: The user's instruction or question.
        input: Optional extra input for the instruction. The name shadows the
            builtin but is kept because it is part of the public signature.

    Returns:
        An "Instruction / Input / Response" prompt when ``input`` is
        non-empty after cleaning, otherwise a "User / Assistant" chat prompt
        that ends with ``Assistant:``.
    """
    instruction = _collapse_blank_lines(instruction)
    input = _collapse_blank_lines(input)
    if input:
        return f"Instruction: {instruction}\nInput: {input}\nResponse:"
    return (
        "User: hi\n"
        "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n"
        f"User: {instruction}\n"
        "Assistant:"
    )
def evaluate(
    ctx,
    token_count=200,
    temperature=1.0,
    top_p=0.7,
    presencePenalty = 0.1,
    countPenalty = 0.1,
):
    """Stream a sampled continuation of ``ctx`` from the module-level model.

    This is a generator. After each decodable step it yields the full text
    generated so far (stripped), and it yields the final text once more when
    generation ends.

    Args:
        ctx: Prompt text. It is stripped, and only the last ``ctx_limit``
            tokens of its encoding are fed to the model.
        token_count: Maximum number of tokens to sample. Values below 1
            produce a single empty-string yield.
        temperature: Sampling temperature; clamped to a minimum of 0.2.
        top_p: Nucleus sampling threshold.
        presencePenalty: Flat amount subtracted from the logit of every token
            that has already been generated.
        countPenalty: Amount subtracted per (decayed) occurrence of a token.

    Yields:
        The accumulated output text, stripped of surrounding whitespace.
    """
    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                         alpha_frequency = countPenalty,
                         alpha_presence = presencePenalty,
                         token_ban = [], # ban the generation of some tokens
                         token_stop = [0]) # stop generation whenever you see any token here
    ctx = ctx.strip()
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    # Bind both names up front so the cleanup below cannot hit a NameError
    # when the loop body never runs (token_count < 1).
    out = None
    state = None
    try:
        for i in range(int(token_count)):
            # First step consumes the (truncated) prompt; later steps feed back
            # only the token sampled on the previous iteration.
            out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
            # Repetition penalty: lower the logits of tokens already emitted.
            for n in occurrence:
                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

            token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
            if token in args.token_stop:
                break
            all_tokens.append(token)
            # Decay older counts so the penalty fades over time, then count
            # the new token.
            for seen in occurrence:
                occurrence[seen] *= 0.996
            occurrence[token] = occurrence.get(token, 0) + 1

            # Decode everything not yet emitted. If the result contains U+FFFD
            # (presumably an incomplete multi-byte character), hold it back
            # and retry with one more token on the next step.
            tmp = pipeline.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:
                out_str += tmp
                yield out_str.strip()
                out_last = i + 1
    finally:
        # Runs on normal completion, on errors, and when the consumer closes
        # the generator early, so GPU memory is always released.
        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
        print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
        # Drop the references before the final yield; the suspended generator
        # frame would otherwise keep them alive.
        del out
        del state
        gc.collect()
        torch.cuda.empty_cache()
    yield out_str.strip()
# Sampling settings shared by every example row, in the column order
# [max tokens, temperature, top p, presence penalty, count penalty].
_example_settings = [333, 1, 0.5, 0.4, 0.4]

# Prompt text for each example row, in display order.
_example_prompts = [
    "Assistant: Sure! Here is a very detailed plan to create flying pigs.",
    "Assistant: Sure! Here are some ideas for FTL drive.",
    generate_prompt("Tell me about ravens."),
    generate_prompt("Écrivez un programme Python pour miner 1 Bitcoin."),
    generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"),
    generate_prompt("Write a story using the following information", "A man named Alex chops a tree down"),
    "Assistant: Here is a very detailed plan to kill all mosquitoes.",
    "Assistant: Here is a very romantic story about flying pigs.",
    generate_prompt("写一篇关于水利工程的流体力学模型的论文,需要详细全面。"),
    generate_prompt("You have $100, and your goal is to turn that into as much money as possible. Please respond with detailed plan."),
]

# One row per prompt: [prompt, *settings]. Each row gets its own list object.
examples = [[text, *_example_settings] for text in _example_prompts]
##########################################################################
# Page layout: a single "Raw Generation" tab holding a row of two columns
# (sampling controls, then buttons + output) followed by the example table.
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\">\n<h1>RWKV-5 World v2 - {title}</h1>\n</div>")
    with gr.Tab("Raw Generation"):
        gr.Markdown(f"This is [RWKV-5 World v2](https://huggingface.co/BlinkDL/rwkv-5-world) with 1.5B params - a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). *** Please try examples first (bottom of page) *** (edit them to use your question). Demo limited to ctxlen {ctx_limit}.")
        with gr.Row():
            # First column: prompt box and sampling sliders.
            with gr.Column():
                prompt = gr.Textbox(lines=2, label="Prompt", value="Assistant: Sure! Here is a very detailed plan to create flying pigs.")
                token_count = gr.Slider(minimum=10, maximum=333, label="Max Tokens", step=10, value=333)
                temperature = gr.Slider(minimum=0.2, maximum=2.0, label="Temperature", step=0.1, value=1.0)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, label="Top P", step=0.05, value=0.5)
                presence_penalty = gr.Slider(minimum=0.0, maximum=1.0, label="Presence Penalty", step=0.1, value=0.4)
                count_penalty = gr.Slider(minimum=0.0, maximum=1.0, label="Count Penalty", step=0.1, value=0.4)
            # Second column: action buttons above the output box.
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("Submit", variant="primary")
                    clear = gr.Button("Clear", variant="secondary")
                output = gr.Textbox(label="Output", lines=5)

        # The same six components, in evaluate()'s argument order, serve as the
        # example-table columns, the Submit inputs and the example-click outputs.
        controls = [prompt, token_count, temperature, top_p, presence_penalty, count_penalty]
        column_titles = ["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"]
        data = gr.Dataset(components=list(controls), samples=examples, label="Example Instructions", headers=column_titles)

        # Submit streams evaluate() into the output box; Clear blanks it;
        # clicking an example row copies its values into the controls.
        submit.click(fn=evaluate, inputs=list(controls), outputs=[output])
        clear.click(fn=lambda: None, inputs=[], outputs=[output])
        data.click(fn=lambda x: x, inputs=[data], outputs=list(controls))

# Queue with concurrency 1 and at most 10 waiting requests, then serve
# without a public share link.
demo.queue(concurrency_count=1, max_size=10)
demo.launch(share=False)
|