import gradio as gr
import torch
import requests
import io
import huggingface_hub
from transformers import BloomForCausalLM, BloomTokenizerFast
import os

repo_id = "szzzzz/chatbot_bloom_560m"

# Create the local cache dir up front; exist_ok avoids a FileExistsError
# when the app restarts with the directory already in place.
os.makedirs("./chatbot", exist_ok=True)

# Download the repo snapshot (config and tokenizer files) but skip the large
# *.bin weight file, which is streamed separately below.
path = huggingface_hub.snapshot_download(
    repo_id=repo_id, cache_dir="./chatbot", ignore_patterns="*bin"
)
url = huggingface_hub.file_download.hf_hub_url(repo_id, "pytorch_model.bin")
tokenizer = BloomTokenizerFast.from_pretrained(path)

# Stream the weights over HTTP straight into memory and load them on CPU,
# rather than writing pytorch_model.bin to disk first.
state_dict = torch.load(
    io.BytesIO(requests.get(url).content), map_location=torch.device("cpu")
)

# Build the model from the downloaded config and the in-memory state dict.
model = BloomForCausalLM.from_pretrained(
    pretrained_model_name_or_path=None,
    state_dict=state_dict,
    config=f"{path}/config.json",
)
# Total token budget for the prompt plus the generated reply.
max_length = 1024


def generate(inputs: str) -> str:
    """Generate a bot response for the given dialogue prompt.

    Args:
        inputs (str):
            The dialogue so far, ending with an open assistant turn,
            e.g. 'Human: 你好 .\n \nAssistant: '
            ("Human: Hello .\n \nAssistant: ").

    Returns:
        str:
            The bot response, e.g. '你好!我是你的ai助手!'
            ("Hello! I am your AI assistant!").
    """
    input_text = tokenizer.bos_token + inputs
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    _, input_len = input_ids.shape
    # Refuse prompts that leave no room for a reply within the token budget.
    if input_len >= max_length - 4:
        # "The conversation exceeds the length limit, please start over."
        res = "对话超过字数限制,请重新开始."
        return res
    # Sample a continuation; max_new_tokens caps the reply at whatever budget
    # remains after the prompt.
    pred_ids = model.generate(
        input_ids,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.8,
        max_new_tokens=max_length - input_len,
        repetition_penalty=1.2,
    )
    # Strip the prompt tokens and decode only the newly generated reply.
    pred = pred_ids[0][input_len:]
    res = tokenizer.decode(pred, skip_special_tokens=True)
    return res
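
# A minimal direct-call sketch of generate, reusing the docstring example
# (illustrative values only):
#     generate("Human: 你好 .\n \nAssistant: ")  # -> e.g. "你好!我是你的ai助手!"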


def add_text(history, text):
    # Append the user turn with a None placeholder for the pending bot reply,
    # and return "" to clear the textbox.
    history = history + [(text, None)]
    return history, ""
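
# E.g. (illustrative values): add_text([], "Hi") -> ([("Hi", None)], "")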


def bot(history):
    # Rebuild the whole dialogue as one prompt. Every finished turn contributes
    # "\nHuman: ..." and "\nAssistant: ..."; the last turn has no reply yet, so
    # the prompt ends with an open "Assistant: " for the model to complete.
    prompt = ""
    for i, h in enumerate(history):
        prompt = prompt + "\nHuman: " + h[0]
        if i != len(history) - 1:
            prompt = prompt + "\nAssistant: " + h[1]
        else:
            prompt = prompt + "\nAssistant: "

    response = generate(prompt)
    history[-1][1] = response
    return history


# Regenerating builds the same prompt as bot (the last assistant turn is simply
# overwritten), so the handler can be reused directly instead of duplicated.
regenerate = bot
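
# For illustration, a hypothetical two-turn history such as
#     [["Hi", "Hello!"], ["How are you?", None]]
# yields the prompt:
#     "\nHuman: Hi\nAssistant: Hello!\nHuman: How are you?\nAssistant: "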


with gr.Blocks() as demo:
    gr.Markdown("""Chatbot of szzzzz.""")

    with gr.Tab("chatbot"):
        gr_chatbot = gr.Chatbot([]).style(height=300)

        txt = gr.Textbox(
            show_label=False,
            placeholder="Enter text and press enter",
        ).style(container=False)
        with gr.Row():
            clear = gr.Button("Restart")
            regen = gr.Button("Regenerate response")

        # Wire up events: submitting the textbox appends the user turn, then
        # runs the bot to fill in the reply.
        txt.submit(add_text, [gr_chatbot, txt], [gr_chatbot, txt]).then(
            bot, gr_chatbot, gr_chatbot
        )

        # "Restart" wipes the history; "Regenerate response" reruns the model
        # on the current history, overwriting the last reply.
        clear.click(lambda: None, None, gr_chatbot, queue=False)
        regen.click(regenerate, [gr_chatbot], [gr_chatbot])


# Binding to 0.0.0.0 makes the app reachable from outside the container.
demo.launch(server_name="0.0.0.0")