larryvrh's picture
Update chat_webui.py
85f025c
raw
history blame contribute delete
No virus
3.65 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
from peft import PeftModel
import torch
from threading import Thread
# Base checkpoint as a (repository id, revision) pair; the revision is
# forwarded unchanged to the model loader.
model_path = ('TigerResearch/tigerbot-13b-chat', None)
# LoRA adapter that is layered on top of the base checkpoint.
lora_path = 'larryvrh/tigerbot-13b-chat-sharegpt-lora'

_base_repo, _base_revision = model_path

tokenizer = AutoTokenizer.from_pretrained(_base_repo)

# 4-bit NF4 weights with double quantisation; compute runs in float16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the quantised base model, then wrap it with the LoRA adapter.
model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(
        _base_repo,
        revision=_base_revision,
        device_map="auto",
        quantization_config=quant_config,
    ),
    lora_path,
)
model.eval()
# Tokens reserved for the generated reply. The prompt is truncated to whatever
# is left of the requested context length after this reservation.
_MAX_NEW_TOKENS = 500


def predict(input, chatbot, max_length, top_p, temperature, rep_penalty, retry):
    """Stream a model reply for the conversation, yielding the growing history.

    This is a generator: every time the streamer produces a new piece of text
    the last history entry is extended and the whole history is yielded.

    Args:
        input: Text of the new user turn. Ignored when ``retry`` is true, in
            which case the user text of the last history entry is reused.
        chatbot: Existing history as a sequence of ``(user_text, reply_text)``
            pairs. It is not modified; a copy is extended and yielded.
        max_length: Total context length in tokens. The prompt is truncated to
            ``max_length - 500`` tokens, but never to fewer than one token.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        temperature: Sampling temperature passed to ``model.generate``.
        rep_penalty: Repetition penalty passed to ``model.generate``.
        retry: When true, drop the last history entry and regenerate its reply.

    Yields:
        The updated history list. A retry on an empty history yields a single
        empty list and stops.
    """
    if retry:
        if not chatbot:
            # Nothing to regenerate.
            yield []
            return
        input = chatbot[-1][0]
        history = list(chatbot[:-1])
    else:
        # Work on a copy so the caller's list is never mutated.
        history = list(chatbot or [])
    history.append((input, ""))

    prompt = '<s>' + ''.join(
        f'\n\n### Instruction:\n{turn[0]}\n\n### Response:\n{turn[1]}' for turn in history
    )
    print('prompt:', repr(prompt))

    # The UI slider allows context lengths at or below the reply reservation;
    # clamp so the tokenizer never receives a zero or negative truncation length.
    prompt_budget = max(1, int(max_length) - _MAX_NEW_TOKENS)
    model_inputs = tokenizer(
        [prompt], return_tensors="pt", truncation=True, max_length=prompt_budget
    ).to(model.device)  # follow the model's reported device instead of assuming 'cuda'

    streamer = TextIteratorStreamer(tokenizer, timeout=15.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=_MAX_NEW_TOKENS,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=rep_penalty,
    )

    # Generation runs in a background thread so text can be consumed from the
    # streamer while it is still being produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    for piece in streamer:
        user_text, reply_so_far = history[-1]
        history[-1] = (user_text, reply_so_far + piece)
        yield history
    # The streamer is exhausted, so generation is finishing; reap the thread.
    worker.join()
def reset_user_input():
    """Build a Gradio update whose value is the empty string."""
    cleared = gr.update(value='')
    return cleared
def reset_state():
    """Return a new, empty list on every call."""
    empty_history = list()
    return empty_history
# Custom CSS handed to gr.Blocks below.
css='''
.contain {max-width:50}
#chatbot {min-height:500px}
'''

# NOTE(review): the source had lost its indentation, so the nesting of the two
# gr.Row() blocks relative to gr.Column() is reconstructed — confirm the layout.
with gr.Blocks(css=css) as demo:
    gr.HTML('<h1 align="center">TigerBot</h1>')
    chatbot = gr.Chatbot(elem_id='chatbot')

    with gr.Column():
        user_input = gr.Textbox(
            show_label=False,
            placeholder="输入",
            lines=1,
        ).style(container=False)
        with gr.Row():
            submitBtn = gr.Button("发送", variant="primary")
            retryBtn = gr.Button("重试")
            cancelBtn = gr.Button('撤销')
            emptyBtn = gr.Button("清空")

    with gr.Row():
        max_length = gr.Slider(0, 4096, value=2048, step=1, label="Context Length", interactive=True)
        top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top-P", interactive=True)
        temperature = gr.Slider(0, 1, value=0.5, step=0.01, label="Temperature", interactive=True)
        rep_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label='Repetition Penalty', interactive=True)

    # Inputs shared by the submit and retry handlers; only the trailing
    # gr.State flag (the `retry` argument of predict) differs between them.
    _generation_inputs = [user_input, chatbot, max_length, top_p, temperature, rep_penalty]

    # Submit: stream a reply into the chatbot, and separately clear the textbox.
    submitBtn.click(
        fn=predict,
        inputs=[*_generation_inputs, gr.State(False)],
        outputs=[chatbot],
        show_progress=False,
    )
    submitBtn.click(fn=reset_user_input, inputs=[], outputs=[user_input], show_progress=False)

    # Retry: same handler with the retry flag set.
    retryBtn.click(
        fn=predict,
        inputs=[*_generation_inputs, gr.State(True)],
        outputs=[chatbot],
        show_progress=False,
    )

    # Undo removes the last history entry; clear replaces the history with [].
    cancelBtn.click(
        fn=lambda history: history[:-1],
        inputs=[chatbot],
        outputs=[chatbot],
        show_progress=False,
    )
    emptyBtn.click(fn=reset_state, outputs=[chatbot], show_progress=False)

# Enable the queue, then start the app.
demo.queue().launch(share=False, inbrowser=True)