File size: 928 Bytes
561ca81
ab85003
561ca81
2432281
ab85003
b74218d
2432281
b74218d
561ca81
 
 
 
 
be6ae85
588129f
 
d903d0d
588129f
561ca81
 
 
 
588129f
561ca81
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from rwkvstic.load import RWKV
import torch
# Enable GPU execution only when CUDA is actually available on this machine.
_use_gpu = torch.cuda.is_available()

# Load the RWKV-4 Pile 1.5B "Instruct-test1" checkpoint from the Hugging Face
# URL below via rwkvstic's "pytorch(cpu/gpu)" backend.
# NOTE(review): dtype / runtimedtype are both float32 — presumably storage vs.
# compute precision; confirm against the rwkvstic loader.
model = RWKV(
    "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth",
    "pytorch(cpu/gpu)",
    dtype=torch.float32,
    runtimedtype=torch.float32,
    useGPU=_use_gpu,
)
import gradio as gr


def predict(input, history=None):
    """Generate one chat turn with the RWKV model.

    Args:
        input: The user's message from the textbox.
        history: Session state laid out as ``[chat_pairs, rwkv_state]``.
            ``chat_pairs`` is the list of ``(user, bot)`` exchanges shown in
            the Chatbot, and ``rwkv_state`` is the model state to resume from.
            ``None`` starts a fresh conversation from ``model.emptyState``.

    Returns:
        A pair ``(chat_pairs, new_history)``. The first item is the updated
        list of exchanges for the Chatbot. The second is
        ``[chat_pairs, rwkv_state]`` to store back into the ``gr.State``.
    """
    # Indexing ``None`` would raise TypeError; treat it as a new conversation.
    if history is None:
        history = [[], model.emptyState]
    past_turns, rwkv_state = history[0], history[1]

    # Resume from the stored state, append this prompt, then generate up to
    # 100 units of output, stopping early if the model starts a new prompt.
    model.setState(rwkv_state)
    model.loadContext(newctx=f"Prompt: {input}\n\nExpert Long Detailed Response: ")
    result = model.forward(number=100, stopStrings=["\n\nPrompt"])

    turns = [*past_turns, (input, result["output"])]
    # Hand the Chatbot and the stored state separate list objects so any
    # in-place post-processing of the displayed value cannot leak into the
    # history kept for the next turn.
    return turns, [list(turns), result["state"]]

# Minimal chat UI: a Chatbot display plus a single textbox; pressing Enter in
# the textbox runs ``predict`` and refreshes both the chat and session state.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    # Per-session value: [chat exchanges, RWKV model state], starting empty.
    state = gr.State([[], model.emptyState])
    with gr.Row():
        txt = gr.Textbox(
            show_label=False,
            placeholder="Enter text and press enter",
        ).style(container=False)

    txt.submit(fn=predict, inputs=[txt, state], outputs=[chatbot, state])

demo.launch()