File size: 2,864 Bytes
1c776f7
5e66ec0
dd48380
 
 
1c776f7
dd48380
1c776f7
 
dd48380
1c776f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd48380
5e66ec0
 
 
 
 
 
 
 
1c776f7
5e66ec0
1c776f7
dd48380
5e66ec0
 
 
1c776f7
 
5e66ec0
1c776f7
 
dd48380
1c776f7
 
dd48380
 
fbcf846
dd48380
 
 
1c776f7
dd48380
 
1c776f7
 
 
 
 
 
dd48380
 
73bf78a
5e66ec0
 
 
 
 
1c776f7
5e66ec0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import urllib
import chatglm_cpp
import gradio as gr

import requests

# Endpoint location and auth token come from the environment (HF Space secrets).
endpoint_url = os.getenv('ENDPOINT_URL')
personal_secret_token = os.getenv('PERSONAL_HF_TOKEN')

# Model-specific prompt-formatting tokens, also injected via the environment;
# used to join turns and tag roles when flattening the chat into one string.
# NOTE(review): these are None if the variables are unset — TODO confirm the
# deployment always provides them.
turn_breaker = os.getenv('TURN_BREAKER')
system_symbol = os.getenv('SYSTEM_SYMBOL')
user_symbol = os.getenv('USER_SYMBOL')
assistant_symbol = os.getenv('ASSISTANT_SYMBOL')

# HTTP headers for the endpoint request (currently unused while `query` is stubbed).
headers = {
    "Accept": "application/json",
    "Authorization": f"Bearer {personal_secret_token}",
    "Content-Type": "application/json",
}

def query(payload):
    """Stubbed inference-endpoint call.

    The real HTTP request is disabled; instead the function echoes the
    prompt text back so the UI can be exercised without the endpoint:

        response = requests.post(endpoint_url, headers=headers, json=payload)
        return response.json()

    Raises KeyError if ``payload`` has no ``"inputs"`` entry.
    """
    prompt_text = payload['inputs']
    return prompt_text

# output = query({
# 	"inputs": "你啥比",
# 	"parameters": {
# 		"max_new_tokens": 150
# 	}
# })
# system_message = chatglm_cpp.ChatMessage(role="system", content="请你现在扮演一个软件工程师,名字叫做贺英旭。你需要以这个身份和朋友们对话。")

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_new_tokens,
    temperature,
    top_p,
):
    """Gradio ChatInterface callback.

    Flattens the system message, all prior (user, assistant) turns, and the
    new user message into a single prompt string — turns joined with
    ``turn_breaker`` and roles tagged with ``user_symbol`` /
    ``assistant_symbol`` — then forwards it to ``query`` together with the
    generation parameters and returns whatever ``query`` yields.
    """
    prompt_parts = [system_message]

    for user_turn, assistant_turn in history:
        # Empty strings (and None) are skipped, matching the endpoint's
        # expectation of non-empty tagged turns.
        if user_turn:
            prompt_parts.append(user_symbol + user_turn)
        if assistant_turn:
            prompt_parts.append(assistant_symbol + assistant_turn)

    prompt_parts.append(user_symbol + message)

    # Sampling is enabled only for strictly positive temperatures.
    parameters = {
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0,
        "top_p": top_p,
        "temperature": temperature,
    }

    payload = {
        "inputs": turn_breaker.join(prompt_parts),
        "parameters": parameters,
    }
    return query(payload)


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="请你现在扮演一个软件工程师,名字叫做贺英旭。你需要以这个身份和朋友们对话。", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()