# chatbot / app.py
# Author: szzzzz (last commit: "Update app.py", a030089)
import gradio as gr
import torch
import requests
import io
import huggingface_hub
from transformers import BloomForCausalLM, BloomTokenizerFast
import os
# Hugging Face Hub repository the tokenizer, config and weights are loaded from.
repo_id = 'szzzzz/chatbot_bloom_560m'
# exist_ok=True: with a plain os.mkdir, starting the app a second time in the
# same working directory raised FileExistsError before anything was loaded.
os.makedirs('./chatbot', exist_ok=True)
# Snapshot everything except files matching "*bin" (e.g. pytorch_model.bin);
# the weights are fetched straight into memory by _download_state_dict below.
path = huggingface_hub.snapshot_download(
    repo_id=repo_id, cache_dir='./chatbot', ignore_patterns="*bin"
)
url = huggingface_hub.file_download.hf_hub_url(repo_id, "pytorch_model.bin")
tokenizer = BloomTokenizerFast.from_pretrained(path)


def _download_state_dict(weights_url):
    """Download a PyTorch checkpoint over HTTP and load it onto the CPU.

    Args:
        weights_url: Direct URL of a ``torch.save``-d state dict.

    Returns:
        The deserialized state dict, with every tensor mapped to the CPU.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status.
        requests.Timeout: if connecting or a single read stalls for 60 s.
    """
    # The timeout bounds the connect phase and each individual read, not the
    # total transfer, so a large but steadily flowing download is unaffected.
    response = requests.get(weights_url, timeout=60)
    # Without this check an HTTP error page would be handed to torch.load and
    # fail with an opaque unpickling error instead of the real cause.
    response.raise_for_status()
    # NOTE(security): torch.load unpickles the payload, which can execute
    # arbitrary code; only load checkpoints from a repository you trust.
    # Keeping the response local lets the raw bytes be freed once this returns.
    return torch.load(
        io.BytesIO(response.content), map_location=torch.device("cpu")
    )


state_dict = _download_state_dict(url)
# No pretrained name: the architecture comes from the snapshot's config.json
# and the weights from the state dict downloaded above.
model = BloomForCausalLM.from_pretrained(
    pretrained_model_name_or_path=None,
    state_dict=state_dict,
    config=f"{path}/config.json",
)
# Token budget shared by the prompt and the reply (see generate()).
max_length = 1024
def generate(inputs: str) -> str:
    """Sample one assistant reply from the model for a dialogue prompt.

    Args:
        inputs: The conversation so far as plain text, made of alternating
            ``Human:`` / ``Assistant:`` lines and ending with an empty
            ``Assistant:`` cue for the model to complete.

    Returns:
        The decoded continuation with special tokens stripped. If the encoded
        prompt leaves 4 or fewer tokens of the ``max_length`` budget, no
        generation happens and a fixed (Chinese) notice asking the user to
        start a new conversation is returned instead.
    """
    # The BOS token is prepended as text before tokenizing.
    encoded = tokenizer.encode(tokenizer.bos_token + inputs, return_tensors="pt")
    prompt_len = encoded.shape[1]

    # Whatever the prompt does not use of max_length is the reply budget.
    budget = max_length - prompt_len
    if budget <= 4:
        return "对话超过字数限制,请重新开始."

    sampling = dict(
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.8,
        max_new_tokens=budget,
        repetition_penalty=1.2,
    )
    output = model.generate(encoded, **sampling)

    # generate() echoes the prompt; keep only the newly produced tokens.
    continuation = output[0, prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True)
def add_text(history, text):
    """Record the user's message as a new turn that still awaits a reply.

    Args:
        history: Current chat history (a list of message pairs).
        text: The message the user just submitted.

    Returns:
        A pair ``(new_history, "")``. ``new_history`` is a fresh list equal to
        ``history`` plus the turn ``(text, None)``; the input list is not
        modified. The empty string clears the textbox.
    """
    pending_turn = (text, None)
    return history + [pending_turn], ""
def _build_prompt(history):
    """Render the chat history as a Human/Assistant transcript for generate().

    Every turn contributes ``"\\nHuman: <message>"``. Every turn except the
    last also contributes ``"\\nAssistant: <reply>"``. The last turn instead
    ends with an empty ``"\\nAssistant: "`` cue. A missing (``None``) reply in
    an earlier turn is rendered as an empty string rather than crashing.
    """
    parts = []
    for user_msg, reply in history[:-1]:
        parts.append("\nHuman: " + user_msg)
        # reply can be None if an earlier generation never completed; without
        # the fallback every later turn would raise TypeError.
        parts.append("\nAssistant: " + (reply or ""))
    parts.append("\nHuman: " + history[-1][0])
    parts.append("\nAssistant: ")
    # join instead of repeated += keeps prompt building linear in its length.
    return "".join(parts)


def bot(history):
    """Generate the assistant reply for the newest turn of the conversation.

    Args:
        history: Chat history as a list of ``[user_message, reply]`` pairs.
            The last pair is the turn to answer; its current reply is ignored.

    Returns:
        The same list object, with the last pair replaced by
        ``[user_message, generated_reply]``. An empty or ``None`` history is
        returned unchanged, because there is no turn to answer.
    """
    if not history:
        return history
    response = generate(_build_prompt(history))
    # Rebuild the pair instead of assigning into it, so this also works when
    # the last entry is a tuple (as produced by add_text).
    history[-1] = [history[-1][0], response]
    return history
def regenerate(history):
    """Discard the newest assistant reply and sample a new one.

    Building the prompt and generating are exactly what ``bot`` does, since
    ``bot`` ignores the last turn's existing reply. This function therefore
    delegates to ``bot`` instead of duplicating its logic. Because generation
    samples (``do_sample=True``), the new reply generally differs from the
    old one.

    Args:
        history: Chat history as a list of ``[user_message, reply]`` pairs.

    Returns:
        The history with the last reply regenerated. An empty or ``None``
        history is returned unchanged.
    """
    # With no turns yet (at startup, or after "Restart" resets the chatbot to
    # None), the old code crashed with IndexError on history[-1].
    if not history:
        return history
    return bot(history)
# UI layout: a chat window, a message box, and Restart / Regenerate buttons.
with gr.Blocks() as demo:
    gr.Markdown("""chatbot of szzzzz.""")
    with gr.Tab("chatbot"):
        gr_chatbot = gr.Chatbot([]).style(height=300)
        txt = gr.Textbox(
            show_label=False,
            placeholder="Enter text and press enter",
        ).style(container=False)
        with gr.Row():
            clear = gr.Button("Restart")
            regen = gr.Button("Regenerate response")

    # Event wiring.
    # Enter in the textbox: first append the user turn and clear the box,
    # then run the model to fill in the reply.
    submit_event = txt.submit(
        add_text,
        inputs=[gr_chatbot, txt],
        outputs=[gr_chatbot, txt],
    )
    submit_event.then(bot, inputs=gr_chatbot, outputs=gr_chatbot)

    # "Restart" resets the chat window without going through the queue.
    clear.click(lambda: None, inputs=None, outputs=gr_chatbot, queue=False)

    # "Regenerate response" re-samples the last assistant reply.
    regen.click(regenerate, inputs=[gr_chatbot], outputs=[gr_chatbot])

# Listen on all interfaces so the app is reachable from outside the container.
demo.launch(server_name="0.0.0.0")