"""Gradio chat demo for the ``szzzzz/chatbot_bloom_560m`` BLOOM model.

At import time this script downloads the tokenizer/config files and the model
weights from the Hugging Face Hub, builds a ``BloomForCausalLM`` from them on
CPU, and launches a Gradio Blocks UI listening on 0.0.0.0.
"""
import io
import os

import gradio as gr
import huggingface_hub
import requests
import torch
from transformers import BloomForCausalLM, BloomTokenizerFast

repo_id = "szzzzz/chatbot_bloom_560m"
cache_dir = "./chatbot"

# exist_ok=True: a plain os.mkdir raises FileExistsError when the script is
# restarted and the cache directory from a previous run is still there.
os.makedirs(cache_dir, exist_ok=True)

# Fetch every repository file except the "*bin" weight files
# (tokenizer files, config.json, ...); the weights are fetched separately below.
path = huggingface_hub.snapshot_download(
    repo_id=repo_id, cache_dir=cache_dir, ignore_patterns="*bin"
)
url = huggingface_hub.file_download.hf_hub_url(repo_id, "pytorch_model.bin")

tokenizer = BloomTokenizerFast.from_pretrained(path)


def _download_state_dict(weights_url: str, timeout: float = 60.0):
    """Download a PyTorch checkpoint into memory and deserialize it on CPU.

    Args:
        weights_url: URL of the checkpoint file.
        timeout: requests connect/read timeout in seconds (applies per socket
            operation, not to the whole transfer).

    Returns:
        Whatever ``torch.load`` deserializes from the payload.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status, so an
            error page is never passed to ``torch.load``.
    """
    response = requests.get(weights_url, timeout=timeout)
    response.raise_for_status()
    # NOTE(security): torch.load unpickles the payload; only load checkpoints
    # from a repository you trust.
    return torch.load(
        io.BytesIO(response.content), map_location=torch.device("cpu")
    )


state_dict = _download_state_dict(url)
model = BloomForCausalLM.from_pretrained(
    pretrained_model_name_or_path=None,
    state_dict=state_dict,
    config=f"{path}/config.json",
)

# Token budget: a prompt of max_length - 4 tokens or more is rejected, and
# generation may add at most (max_length - prompt length) new tokens.
max_length = 1024


def generate(inputs: str) -> str:
    """Generate the bot's continuation of a chat prompt.

    Args:
        inputs: prompt text in the format produced by ``_build_prompt``, i.e.
            alternating "Human: ..." / "Assistant: ..." lines (each preceded by
            a newline) that ends with an open "Assistant: " tag.

    Returns:
        The decoded newly generated text (special tokens stripped). If the
        prompt is too long, a fixed "conversation too long, please restart"
        message is returned instead of running the model.
    """
    input_text = tokenizer.bos_token + inputs
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    _, input_len = input_ids.shape
    if input_len >= max_length - 4:
        res = "对话超过字数限制,请重新开始."
        return res
    pred_ids = model.generate(
        input_ids,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.8,
        max_new_tokens=max_length - input_len,
        repetition_penalty=1.2,
    )
    # The output sequence starts with the prompt; keep only the new tokens.
    pred = pred_ids[0][input_len:]
    res = tokenizer.decode(pred, skip_special_tokens=True)
    return res


def _build_prompt(history) -> str:
    """Render chat history into the "Human:" / "Assistant:" prompt format.

    Every turn except the last contributes both its user message and its bot
    reply. The last turn contributes only its user message followed by an open
    "Assistant: " tag for the model to continue, so any existing reply on the
    last turn is ignored. A missing (None) earlier reply is rendered as empty.
    """
    parts = []
    last = len(history) - 1
    for i, (user_msg, bot_msg) in enumerate(history):
        parts.append("\nHuman: " + user_msg)
        if i != last:
            parts.append("\nAssistant: " + (bot_msg or ""))
        else:
            parts.append("\nAssistant: ")
    return "".join(parts)


def add_text(history, text):
    """Append ``text`` as a new user turn with no reply yet.

    Returns:
        The extended history and an empty string (used to clear the textbox).
    """
    history = (history or []) + [[text, None]]
    return history, ""


def bot(history):
    """Fill in the bot reply for the newest user turn in ``history``.

    An empty or None history (e.g. right after "Restart") is returned as an
    empty list without calling the model.
    """
    if not history:
        return []
    response = generate(_build_prompt(history))
    # Rebuild the last pair rather than item-assigning into it, so turns stored
    # as tuples work as well as lists.
    history[-1] = [history[-1][0], response]
    return history


def regenerate(history):
    """Re-sample the reply to the newest user turn, replacing the old one.

    The prompt never includes the last turn's existing reply, so this is the
    same operation as ``bot``.
    """
    return bot(history)


with gr.Blocks() as demo:
    gr.Markdown("""chatbot of szzzzz.""")
    with gr.Tab("chatbot"):
        gr_chatbot = gr.Chatbot([]).style(height=300)
        txt = gr.Textbox(
            show_label=False,
            placeholder="Enter text and press enter",
        ).style(container=False)
        with gr.Row():
            clear = gr.Button("Restart")
            regen = gr.Button("Regenerate response")

    # Event wiring: submitting text first appends the user turn, then runs the
    # model on the updated history.
    txt.submit(add_text, [gr_chatbot, txt], [gr_chatbot, txt]).then(
        bot, gr_chatbot, gr_chatbot
    )
    clear.click(lambda: None, None, gr_chatbot, queue=False)
    regen.click(regenerate, [gr_chatbot], [gr_chatbot])

demo.launch(server_name="0.0.0.0")