Spaces: Runtime error
init project
- README.md +0 -13
- apps/__init__.py +3 -0
- apps/components.py +22 -0
- apps/instruction_chat.py +106 -0
- apps/simple_chat.py +90 -0
- apps/translator.py +57 -0
- chatClient.py +99 -0
- chatglm2_6b/__init__.py +0 -0
- chatglm2_6b/modelClient.py +82 -0
- chatglm2_6b/server.py +49 -0
- config.py +7 -0
- gallery.gradio.py +47 -0
- runserver.py +11 -0
README.md
DELETED
@@ -1,13 +0,0 @@
----
-title: Chatglm2 6b Explorer
-emoji: 🐠
-colorFrom: green
-colorTo: blue
-sdk: gradio
-sdk_version: 3.37.0
-app_file: app.py
-pinned: false
-license: apache-2.0
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
apps/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .translator import translator_demo
+from .simple_chat import simple_chat_demo
+from .instruction_chat import instruction_chat_demo
apps/components.py
ADDED
@@ -0,0 +1,22 @@
+import gradio as gr
+
+
+def chat_accordion():
+    with gr.Accordion("参数设置", open=False):
+        temperature = gr.Slider(
+            minimum=0.1,
+            maximum=2.0,
+            value=0.8,
+            step=0.1,
+            interactive=True,
+            label="Temperature",
+        )
+        top_p = gr.Slider(
+            minimum=0.1,
+            maximum=0.99,
+            value=0.9,
+            step=0.01,
+            interactive=True,
+            label="top_p",
+        )
+    return temperature, top_p
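For context, a minimal sketch (not part of this commit) of how the accordion's sliders wire into a Gradio event handler; the Textbox names here are illustrative assumptions:

import gradio as gr
from apps.components import chat_accordion

with gr.Blocks() as demo:
    box = gr.Textbox(label="input")   # illustrative component, not from the repo
    out = gr.Textbox(label="output")
    temperature, top_p = chat_accordion()
    # slider values arrive as plain floats alongside the text value
    box.submit(lambda text, t, p: f"{text} (T={t}, top_p={p})", [box, temperature, top_p], [out])

demo.launch()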
apps/instruction_chat.py
ADDED
@@ -0,0 +1,106 @@
+import traceback
+
+import gradio as gr
+
+from chatClient import ChatClient
+from apps.components import chat_accordion
+
+BOT_NAME = "ChatGLM2-6B"
+TITLE = """<h3 align="center">🤗 ChatGLM2-6B 预设指令对话</h3>"""
+RETRY_COMMAND = "/retry"
+DEFAULT_INSTRUCTIONS = "你是一个爱笑的机器人,名字叫小莫,在回答问题的时候会添加合适的emoji来表达情绪。"
+
+
+def chat(client: ChatClient):
+    with gr.Row():
+        with gr.Column(elem_id="chat_container", scale=3):
+            with gr.Row():
+                chatbot = gr.Chatbot(elem_id="chatbot")
+            with gr.Row():
+                inputs = gr.Textbox(
+                    placeholder=f"你好 {BOT_NAME} !",
+                    label="输入内容后点击回车",
+                    max_lines=3,
+                )
+            with gr.Row(elem_id="button_container"):
+                with gr.Column():
+                    retry_button = gr.Button("♻️ 重试上一轮对话")
+                with gr.Column():
+                    delete_turn_button = gr.Button("🧽 删除上一轮对话")
+                with gr.Column():
+                    clear_chat_button = gr.Button("✨ 删除全部对话历史")
+
+        with gr.Column(elem_id="param_container", scale=1):
+            with gr.Row():
+                with gr.Accordion("对话预设指令", open=True):
+                    instructions = gr.Textbox(
+                        placeholder="LLM instructions",
+                        value=DEFAULT_INSTRUCTIONS,
+                        lines=10,
+                        interactive=True,
+                        label="指令",
+                        max_lines=16,
+                        show_label=False,
+                    )
+            with gr.Row():
+                temperature, top_p = chat_accordion()
+
+    def run_chat(message: str, chat_history, instructions: str, temperature: float, top_p: float):
+        if not message or (message == RETRY_COMMAND and len(chat_history) == 0):
+            yield chat_history
+            return
+
+        if message == RETRY_COMMAND and chat_history:
+            prev_turn = chat_history.pop(-1)
+            user_message, _ = prev_turn
+            message = user_message
+
+        # chat_history = chat_history + [[message, ""]]
+        try:
+            stream = client.instruct_chat(
+                message,
+                chat_history,
+                instructions,
+                temperature=temperature,
+                top_p=top_p,
+            )
+            for resp, history in stream:
+                chat_history = history
+                yield chat_history
+        except Exception as e:
+            if not chat_history:
+                chat_history = []
+            chat_history += [["有错误了", traceback.format_exc()]]
+            yield chat_history
+
+    def delete_last_turn(chat_history):
+        if chat_history:
+            chat_history.pop(-1)
+        return {chatbot: gr.update(value=chat_history)}
+
+    def run_retry(message: str, chat_history, instructions, temperature: float, top_p: float):
+        yield from run_chat(RETRY_COMMAND, chat_history, instructions, temperature, top_p)
+
+    def clear_chat():
+        return []
+
+    inputs.submit(
+        run_chat,
+        [inputs, chatbot, instructions, temperature, top_p],
+        outputs=[chatbot],
+        show_progress=False,
+    )
+    inputs.submit(lambda: "", inputs=None, outputs=inputs)
+    delete_turn_button.click(delete_last_turn, inputs=[chatbot], outputs=[chatbot])
+    retry_button.click(
+        run_retry,
+        [inputs, chatbot, instructions, temperature, top_p],
+        outputs=[chatbot],
+        show_progress=False,
+    )
+    clear_chat_button.click(clear_chat, [], chatbot)
+
+
+def instruction_chat_demo(client: ChatClient):
+    gr.HTML(TITLE)
+    chat(client)
apps/simple_chat.py
ADDED
@@ -0,0 +1,90 @@
+import gradio as gr
+from chatClient import ChatClient
+import traceback
+from apps.components import chat_accordion
+
+BOT_NAME = "ChatGLM2-6B"
+TITLE = """<h3 align="center">🤖 通用对话</h3>"""
+RETRY_COMMAND = "/retry"
+
+
+def chat(client: ChatClient):
+    with gr.Row():
+        with gr.Column(elem_id="chat_container", scale=3):
+            with gr.Row():
+                chatbot = gr.Chatbot(elem_id="chatbot")
+            with gr.Row():
+                inputs = gr.Textbox(
+                    placeholder=f"你好 {BOT_NAME} !",
+                    label="输入内容后点击回车",
+                    max_lines=3,
+                )
+            with gr.Row(elem_id="button_container"):
+                with gr.Column():
+                    retry_button = gr.Button("♻️ 重试上一轮对话")
+                with gr.Column():
+                    delete_turn_button = gr.Button("🧽 删除上一轮对话")
+                with gr.Column():
+                    clear_chat_button = gr.Button("✨ 删除全部对话历史")
+
+        with gr.Column(elem_id="param_container", scale=1):
+            temperature, top_p = chat_accordion()
+
+    def run_chat(message: str, chat_history, temperature: float, top_p: float):
+        if not message or (message == RETRY_COMMAND and len(chat_history) == 0):
+            yield chat_history
+            return
+
+        if message == RETRY_COMMAND and chat_history:
+            prev_turn = chat_history.pop(-1)
+            user_message, _ = prev_turn
+            message = user_message
+
+        # chat_history = chat_history + [[message, ""]]
+        try:
+            stream = client.simple_chat(
+                message,
+                chat_history,
+                temperature=temperature,
+                top_p=top_p,
+            )
+            for resp, history in stream:
+                chat_history = history
+                yield chat_history
+        except Exception as e:
+            if not chat_history:
+                chat_history = []
+            chat_history += [["有错误了", traceback.format_exc()]]
+            yield chat_history
+
+    def delete_last_turn(chat_history):
+        if chat_history:
+            chat_history.pop(-1)
+        return {chatbot: gr.update(value=chat_history)}
+
+    def run_retry(message: str, chat_history, temperature: float, top_p: float):
+        yield from run_chat(RETRY_COMMAND, chat_history, temperature, top_p)
+
+    def clear_chat():
+        return []
+
+    inputs.submit(
+        run_chat,
+        [inputs, chatbot, temperature, top_p],
+        outputs=[chatbot],
+        show_progress=False,
+    )
+    inputs.submit(lambda: "", inputs=None, outputs=inputs)
+    delete_turn_button.click(delete_last_turn, inputs=[chatbot], outputs=[chatbot])
+    retry_button.click(
+        run_retry,
+        [inputs, chatbot, temperature, top_p],
+        outputs=[chatbot],
+        show_progress=False,
+    )
+    clear_chat_button.click(clear_chat, [], chatbot)
+
+
+def simple_chat_demo(client: ChatClient):
+    gr.HTML(TITLE)
+    chat(client)
apps/translator.py
ADDED
@@ -0,0 +1,58 @@
+import traceback
+
+import gradio as gr
+
+from apps.components import chat_accordion
+
+IDEA_TITLE = "ChatGLM2-6B 翻译官"
+
+prompt_tmpl = """imagine you are a professional translator. Your task is translating the text around by ``` to Chinese.
+
+input text:
+
+```
+{input_text}
+```
+
+translation result:"""
+
+
+def translator_demo(client):
+
+    def stream_translate(input_text, temperature: float, top_p: float):
+        if not input_text:
+            return None
+        message = prompt_tmpl.format(input_text=input_text)
+        try:
+            stream = client.simple_chat(
+                message,
+                [],
+                temperature=temperature,
+                top_p=top_p,
+            )
+            resp = ""  # guard: the stream may yield nothing
+            for resp, _ in stream:
+                pass
+            return resp
+        except Exception:
+            return traceback.format_exc()
+
+    def clear_content():
+        return None, None
+
+    with gr.Row():
+        with gr.Column():
+            inputs = gr.Textbox(label="请输入原文", max_lines=5)
+            gr.Dropdown(["en -> zh"], value="en -> zh", label="翻译语言")
+            temperature, top_p = chat_accordion()
+            with gr.Row(elem_id="button_container"):
+                with gr.Column():
+                    commit_btn = gr.Button(value="翻译", variant='primary')
+                with gr.Column():
+                    clear_btn = gr.Button(value="清空")
+
+        with gr.Column():
+            outputs = gr.Textbox(label="译文", max_lines=5)
+
+    commit_btn.click(stream_translate, inputs=[inputs, temperature, top_p], outputs=[outputs])
+    clear_btn.click(clear_content, inputs=None, outputs=[inputs, outputs])
chatClient.py
ADDED
@@ -0,0 +1,99 @@
+import json
+
+from websockets.exceptions import ConnectionClosedOK
+from websockets.sync.client import connect
+
+from chatglm2_6b.modelClient import ChatGLM2
+import abc
+
+
+class ChatClient(abc.ABC):
+    @abc.abstractmethod
+    def simple_chat(self, query, history, temperature, top_p):
+        pass
+
+    @abc.abstractmethod
+    def instruct_chat(self, message, chat_history, instructions, temperature, top_p):
+        pass
+
+
+def format_chat_prompt(message: str, chat_history, instructions: str) -> str:
+    instructions = instructions.strip(" ").strip("\n")
+    prompt = f"对话背景设定:{instructions}"
+    for i, (user_message, bot_message) in enumerate(chat_history):
+        prompt = f"{prompt}\n\n[Round {i + 1}]\n\n问:{user_message}\n\n答:{bot_message}"
+    prompt = f"{prompt}\n\n[Round {len(chat_history)+1}]\n\n问:{message}\n\n答:"
+    return prompt
+
+
+class ChatGLM2APIClient(ChatClient):
+    def __init__(self, ws_url=None):
+        self.ws_url = "ws://localhost:10001"
+        if ws_url:
+            self.ws_url = ws_url
+
+    def simple_chat(self, query, history, temperature, top_p):
+        """Chat via the dialogue method defined by the chatglm2-6b model."""
+        url = f"{self.ws_url}/streamChat"
+        with connect(url) as websocket:
+            msg = json.dumps({
+                "query": query, "history": history,
+                "temperature": temperature, "top_p": top_p,
+            })
+            websocket.send(msg)
+
+            data = None
+            try:
+                while True:
+                    data = websocket.recv()
+                    data = json.loads(data)
+                    yield data['resp'], data['history']
+            except ConnectionClosedOK:
+                print("generation is finished")
+
+    def instruct_chat(self, message, chat_history, instructions, temperature, top_p):
+        """Instruction-preset chat implemented on top of chatglm2-6b text_generate."""
+        url = f"{self.ws_url}/streamGenerate"
+
+        prompt = format_chat_prompt(message, chat_history, instructions)
+        chat_history = chat_history + [[message, ""]]
+        params = json.dumps({"prompt": prompt, "temperature": temperature, "top_p": top_p})
+        with connect(url) as websocket:
+            websocket.send(params)
+
+            data = None
+            try:
+                while True:
+                    data = websocket.recv()
+                    data = json.loads(data)
+                    resp = data['text']
+
+                    last_turn = list(chat_history.pop(-1))
+                    last_turn[-1] = resp
+                    chat_history = chat_history + [last_turn]
+                    yield resp, chat_history
+            except ConnectionClosedOK:
+                print("generation is finished")
+
+
+class ChatGLM2ModelClient(ChatClient):
+    def __init__(self, model_path=None):
+        self.model = ChatGLM2(model_path)
+
+    def simple_chat(self, query, history, temperature, top_p):
+        kwargs = {
+            "query": query, "history": history,
+            "temperature": temperature, "top_p": top_p,
+        }
+        for resp, history in self.model.stream_chat(**kwargs):
+            yield resp, history
+
+    def instruct_chat(self, message, chat_history, instructions, temperature, top_p):
+        prompt = format_chat_prompt(message, chat_history, instructions)
+        chat_history = chat_history + [[message, ""]]
+        kwargs = {"prompt": prompt, "temperature": temperature, "top_p": top_p}
+        for resp in self.model.stream_generate(**kwargs):
+            last_turn = list(chat_history.pop(-1))
+            last_turn[-1] = resp
+            chat_history = chat_history + [last_turn]
+            yield resp, chat_history
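To make the prompt layout concrete, a sketch (not part of the commit) of what format_chat_prompt renders for one prior turn; importing chatClient assumes the repo's dependencies (torch, transformers, websockets) are installed, and the sample messages are invented:

from chatClient import format_chat_prompt

prompt = format_chat_prompt("今天天气怎么样?", [["你好", "你好!"]], "你是小莫。")
print(prompt)
# 对话背景设定:你是小莫。
#
# [Round 1]
#
# 问:你好
#
# 答:你好!
#
# [Round 2]
#
# 问:今天天气怎么样?
#
# 答: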
chatglm2_6b/__init__.py
ADDED
File without changes
chatglm2_6b/modelClient.py
ADDED
@@ -0,0 +1,82 @@
+from typing import List, Tuple
+
+import torch
+from transformers import AutoTokenizer, AutoModel
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList
+
+DEFAULT_MODEL_PATH = "THUDM/chatglm2-6b"
+
+
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 5] = 5e4
+        return scores
+
+
+class ChatGLM2(object):
+    def __init__(self, model_path=None):
+        # fall back to the default checkpoint when no path is given
+        self.model_path = model_path or DEFAULT_MODEL_PATH
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+        model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half().cuda()
+        self.model = model.eval()
+
+    def generate(
+        self,
+        prompt: str,
+        do_sample: bool = True,
+        max_length: int = 8192,
+        num_beams: int = 1,
+        temperature: float = 0.8,
+        top_p: float = 0.8,
+    ):
+        logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor}
+        inputs = self.tokenizer([prompt], return_tensors="pt")
+        inputs = inputs.to(self.model.device)
+        outputs = self.model.generate(**inputs, **gen_kwargs)
+        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+        response = self.tokenizer.decode(outputs)
+        response = self.model.process_response(response)
+        return response
+
+    def stream_generate(
+        self,
+        prompt: str,
+        do_sample: bool = True,
+        max_length: int = 8192,
+        temperature: float = 0.8,
+        top_p: float = 0.8,
+    ):
+        logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor}
+        inputs = self.tokenizer([prompt], return_tensors="pt")
+        inputs = inputs.to(self.model.device)
+        for outputs in self.model.stream_generate(**inputs, **gen_kwargs):
+            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+            response = self.tokenizer.decode(outputs)
+            if response and response[-1] != "�":
+                response = self.model.process_response(response)
+                yield response
+
+    def stream_chat(
+        self,
+        query: str,
+        history: List[Tuple[str, str]],
+        max_length: int = 8192,
+        do_sample=True,
+        top_p=0.8,
+        temperature=0.8
+    ):
+        stream = self.model.stream_chat(self.tokenizer, query, history,
+                                        max_length=max_length, do_sample=do_sample, top_p=top_p, temperature=temperature)
+        for resp, new_history in stream:
+            yield resp, new_history
+
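A sketch of driving the wrapper directly; it assumes a CUDA GPU is present and that the THUDM/chatglm2-6b weights can be fetched, and the query string is invented:

from chatglm2_6b.modelClient import ChatGLM2

model = ChatGLM2()  # downloads THUDM/chatglm2-6b on first use
resp = ""
for resp, history in model.stream_chat("你好", []):
    pass  # each iteration yields the partial response so far
print(resp)  # final response text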
chatglm2_6b/server.py
ADDED
@@ -0,0 +1,51 @@
+import functools
+
+import anyio
+from fastapi import FastAPI, WebSocket
+from pydantic import BaseModel
+
+from chatglm2_6b.modelClient import ChatGLM2
+from config import Settings
+
+app = FastAPI()
+
+chat_glm2 = ChatGLM2(Settings.CHATGLM_MODEL_PATH)
+
+
+class ChatParams(BaseModel):
+    prompt: str
+    do_sample: bool = True
+    max_length: int = 2048
+    temperature: float = 0.8
+    top_p: float = 0.8
+
+
+@app.post("/generate")
+def generate(params: ChatParams):
+    input_params = params.dict()
+    text = chat_glm2.generate(**input_params)
+    return {"text": text}
+
+
+@app.websocket("/streamGenerate")
+async def stream_generate(websocket: WebSocket):
+    await websocket.accept()
+    params = await websocket.receive_json()
+    func = functools.partial(chat_glm2.stream_generate, **params)
+    # note: run_sync only constructs the generator; each model step still runs on the event loop during iteration below
+    stream = await anyio.to_thread.run_sync(func)
+    for resp in stream:
+        await websocket.send_json({"text": resp})
+    await websocket.close()
+
+
+@app.websocket("/streamChat")
+async def stream_chat(websocket: WebSocket):
+    await websocket.accept()
+    params = await websocket.receive_json()
+    func = functools.partial(chat_glm2.stream_chat, **params)
+    # as above: iteration blocks the event loop while tokens are generated
+    stream = await anyio.to_thread.run_sync(func)
+    for resp, history in stream:
+        await websocket.send_json({"resp": resp, "history": history})
+    await websocket.close()
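A hypothetical smoke test against the blocking POST /generate route, assuming the server from runserver.py is listening on port 10001 and the requests package is available:

import requests

r = requests.post(
    "http://localhost:10001/generate",
    json={"prompt": "你好", "max_length": 256},
)
print(r.json()["text"])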
config.py
ADDED
@@ -0,0 +1,7 @@
+import os
+
+
+class Settings:
+    CHAT_CLIENT = os.environ.get('CHAT_CLIENT', "ChatGLM2APIClient")
+    MODEL_WS_URL = os.environ.get('MODEL_WS_URL', "ws://localhost:10001")
+    CHATGLM_MODEL_PATH = os.environ.get('CHATGLM_MODEL_PATH', "THUDM/chatglm2-6b")
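Settings are read once at import time, so overrides must be exported before config is imported; a sketch with a hypothetical local checkpoint path:

import os

# must run before `from config import Settings`
os.environ["CHAT_CLIENT"] = "ChatGLM2ModelClient"
os.environ["CHATGLM_MODEL_PATH"] = "/models/chatglm2-6b"  # hypothetical path

from config import Settings
assert Settings.CHAT_CLIENT == "ChatGLM2ModelClient"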
gallery.gradio.py
ADDED
@@ -0,0 +1,46 @@
+import gradio as gr
+from apps import translator_demo, simple_chat_demo, instruction_chat_demo
+from chatClient import ChatClient, ChatGLM2APIClient, ChatGLM2ModelClient
+from config import Settings
+
+TITLE = """<h2 align="center">🚀 ChatGLM2-6B apps gallery</h2>"""
+
+demo_register = {
+    "通用对话": simple_chat_demo,
+    "预设指令对话": instruction_chat_demo,
+    "翻译器": translator_demo,
+}
+
+
+def get_gallery(client: ChatClient):
+    with gr.Blocks(
+        # css=None
+        # css="""#chat_container {width: 700px; margin-left: auto; margin-right: auto;}
+        # #button_container {width: 700px; margin-left: auto; margin-right: auto;}
+        # #param_container {width: 700px; margin-left: auto; margin-right: auto;}"""
+        css="""#chatbot {
+            font-size: 14px;
+            min-height: 300px;
+        }"""
+    ) as demo:
+        gr.HTML(TITLE)
+        for name, demo_func in demo_register.items():
+            with gr.Tab(name):
+                demo_func(client)
+    return demo
+
+
+def build_client():
+    client_class = Settings.CHAT_CLIENT
+    if client_class == 'ChatGLM2ModelClient':
+        return ChatGLM2ModelClient(Settings.CHATGLM_MODEL_PATH)
+    if client_class == 'ChatGLM2APIClient':
+        return ChatGLM2APIClient(Settings.MODEL_WS_URL)
+    raise Exception(f"Wrong ChatClient: {client_class}")
+
+
+if __name__ == "__main__":
+    client = build_client()
+    demo = get_gallery(client)
+    demo.queue(max_size=128, concurrency_count=16)
+    demo.launch()
runserver.py
ADDED
@@ -0,0 +1,11 @@
+import uvicorn
+
+from chatglm2_6b.server import app
+
+
+def runserver():
+    uvicorn.run(app, host="0.0.0.0", port=10001)
+
+
+if __name__ == '__main__':
+    runserver()