import io
import os

import gradio as gr
import huggingface_hub
import requests
import torch
from transformers import BloomForCausalLM, BloomTokenizerFast
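
# Fetch the repo snapshot without the weight file, then stream
# pytorch_model.bin separately so the weights can be loaded straight from
# memory instead of being written to the local cache first.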
|
repo_id = "szzzzz/chatbot_bloom_560m"

os.makedirs("./chatbot", exist_ok=True)
path = huggingface_hub.snapshot_download(
    repo_id=repo_id, cache_dir="./chatbot", ignore_patterns="*bin"
)
url = huggingface_hub.file_download.hf_hub_url(repo_id, "pytorch_model.bin")

tokenizer = BloomTokenizerFast.from_pretrained(path)
# Download the weights into memory and load them onto the CPU.
state_dict = torch.load(
    io.BytesIO(requests.get(url).content), map_location=torch.device("cpu")
)
model = BloomForCausalLM.from_pretrained(
    pretrained_model_name_or_path=None,
    state_dict=state_dict,
    config=f"{path}/config.json",
)

# Overall token budget for one conversation: prompt plus generated reply.
max_length = 1024
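

# Core generation routine: nucleus sampling (top_p) with a mild repetition
# penalty, run on the CPU-loaded model above.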
|
def generate(inputs: str) -> str:
    """Generate a bot response for the given dialogue prompt.

    Args:
        inputs (str):
            Dialogue prompt, e.g. 'Human: 你好.\n\nAssistant: '

    Returns:
        str:
            The bot's reply, e.g. '你好!我是你的ai助手!'
    """
    # Prepend BOS so the model sees a fresh sequence start.
    input_text = tokenizer.bos_token + inputs
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    _, input_len = input_ids.shape
    if input_len >= max_length - 4:
        # "The conversation exceeds the length limit, please restart."
        return "对话超过字数限制,请重新开始."
    pred_ids = model.generate(
        input_ids,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.8,
        max_new_tokens=max_length - input_len,
        repetition_penalty=1.2,
    )
    # Keep only the newly generated tokens, dropping the echoed prompt.
    pred = pred_ids[0][input_len:]
    return tokenizer.decode(pred, skip_special_tokens=True)
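

# Example: generate("\nHuman: 你好\nAssistant: ") returns a sampled reply,
# e.g. "你好!我是你的ai助手!" (the prompt format matches what bot() builds
# below; sampling makes the exact reply vary).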
|


def add_text(history, text):
    # Append the user's turn with an empty slot for the pending reply, and
    # clear the textbox.
    history = history + [(text, None)]
    return history, ""
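

# Gradio passes the chat history as [user, bot] pairs; the newest pair's bot
# slot stays None until bot() fills it in.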
|
def bot(history):
    # Serialize the whole dialogue into one prompt, leaving the final
    # "Assistant:" turn open for the model to complete.
    prompt = ""
    for i, h in enumerate(history):
        prompt = prompt + "\nHuman: " + h[0]
        if i != len(history) - 1:
            prompt = prompt + "\nAssistant: " + h[1]
        else:
            prompt = prompt + "\nAssistant: "

    response = generate(prompt)
    history[-1][1] = response
    return history
|


def regenerate(history):
    # Identical to bot(): the last "Assistant:" turn is rebuilt open-ended, so
    # a fresh sample simply replaces the previous reply.
    return bot(history)
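

# Minimal Gradio UI: one chat tab with a textbox plus Restart / Regenerate
# buttons wired to the handlers above.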
|
with gr.Blocks() as demo:
    gr.Markdown("""chatbot of szzzzz.""")

    with gr.Tab("chatbot"):
        gr_chatbot = gr.Chatbot([]).style(height=300)
        txt = gr.Textbox(
            show_label=False,
            placeholder="Enter text and press enter",
        ).style(container=False)
        with gr.Row():
            clear = gr.Button("Restart")
            regen = gr.Button("Regenerate response")

    # Pressing Enter appends the user turn, then lets bot() fill in the reply.
    txt.submit(add_text, [gr_chatbot, txt], [gr_chatbot, txt]).then(
        bot, gr_chatbot, gr_chatbot
    )
    # Restart clears the chat; Regenerate resamples the last reply in place.
    clear.click(lambda: None, None, gr_chatbot, queue=False)
    regen.click(regenerate, [gr_chatbot], [gr_chatbot])


demo.launch(server_name="0.0.0.0")