chensh123 committed
Commit e44ba7b
Parent: 56e1f20

Upload 4 files

Files changed (4)
  1. app.py +100 -0
  2. cli_demo.py +58 -0
  3. gitattributes +35 -0
  4. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,100 @@
+ import os
+
+ # dashscope is not listed in requirements.txt, so it is installed at runtime.
+ os.system('pip install dashscope')
+
+ import gradio as gr
+ from http import HTTPStatus
+ import dashscope
+ from dashscope import Generation
+ from dashscope.api_entities.dashscope_response import Role
+ from typing import List, Optional, Tuple, Dict
+ from urllib.error import HTTPError
+
+ default_system = 'You are a helpful assistant.'
+
+ YOUR_API_TOKEN = os.getenv('YOUR_API_TOKEN')
+ dashscope.api_key = YOUR_API_TOKEN
+
+ History = List[Tuple[str, str]]
+ Messages = List[Dict[str, str]]
+
+
+ def clear_session() -> Tuple[str, History]:
+     return '', []
+
+
+ def modify_system_session(system: str) -> Tuple[str, str, History]:
+     if system is None or len(system) == 0:
+         system = default_system
+     return system, system, []
+
+
+ def history_to_messages(history: History, system: str) -> Messages:
+     messages = [{'role': Role.SYSTEM, 'content': system}]
+     for h in history:
+         messages.append({'role': Role.USER, 'content': h[0]})
+         messages.append({'role': Role.ASSISTANT, 'content': h[1]})
+     return messages
+
+
+ def messages_to_history(messages: Messages) -> Tuple[str, History]:
+     assert messages[0]['role'] == Role.SYSTEM
+     system = messages[0]['content']
+     history = []
+     for q, r in zip(messages[1::2], messages[2::2]):
+         history.append([q['content'], r['content']])
+     return system, history
+
+
+ def model_chat(query: Optional[str], history: Optional[History], system: str
+                ) -> Tuple[str, History, str]:
+     if query is None:
+         query = ''
+     if history is None:
+         history = []
+     messages = history_to_messages(history, system)
+     messages.append({'role': Role.USER, 'content': query})
+     gen = Generation.call(
+         model='qwen1.5-32b-chat',
+         messages=messages,
+         result_format='message',
+         stream=True
+     )
+     # Stream partial completions back to the UI as they arrive.
+     for response in gen:
+         if response.status_code == HTTPStatus.OK:
+             role = response.output.choices[0].message.role
+             content = response.output.choices[0].message.content
+             system, history = messages_to_history(messages + [{'role': role, 'content': content}])
+             yield '', history, system
+         else:
+             raise HTTPError('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
+                 response.request_id, response.status_code,
+                 response.code, response.message
+             ))
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("""<p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 80px"/><p>""")  # todo
+     gr.Markdown("""<center><font size=8>Qwen-32B-Chat Bot👾</center>""")
+     gr.Markdown("""<center><font size=4>Qwen-32B (Tongyi Qianwen 32B) is the 32-billion-parameter model in the Tongyi Qianwen large model series developed by Alibaba Cloud.</center>""")
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             system_input = gr.Textbox(value=default_system, lines=1, label='System')
+         with gr.Column(scale=1):
+             modify_system = gr.Button("🛠️ Set system prompt and clear history", scale=2)
+         system_state = gr.Textbox(value=default_system, visible=False)
+     chatbot = gr.Chatbot(label='Qwen-32B-Chat')
+     textbox = gr.Textbox(lines=2, label='Input')
+
+     with gr.Row():
+         clear_history = gr.Button("🧹 Clear history")
+         submit = gr.Button("🚀 Send")
+
+     submit.click(model_chat,
+                  inputs=[textbox, chatbot, system_state],
+                  outputs=[textbox, chatbot, system_input])
+     clear_history.click(fn=clear_session,
+                         inputs=[],
+                         outputs=[textbox, chatbot])
+     modify_system.click(fn=modify_system_session,
+                         inputs=[system_input],
+                         outputs=[system_state, system_input, chatbot])
+
+ demo.queue(api_open=False)
+ demo.launch(max_threads=30)
cli_demo.py ADDED
@@ -0,0 +1,58 @@
+ import os
+ import platform
+ import signal
+ from transformers import AutoTokenizer, AutoModel
+ try:
+     import readline  # enables line editing and history in input()
+ except ImportError:
+     pass  # readline is unavailable on Windows
+
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+ model = model.eval()
+
+ os_name = platform.system()
+ clear_command = 'cls' if os_name == 'Windows' else 'clear'
+ stop_stream = False
+
+ welcome_prompt = ("Welcome to the ChatGLM-6B model. Type a message to chat, "
+                   "'clear' to clear the history, 'stop' to exit.")
+
+
+ def build_prompt(history):
+     prompt = welcome_prompt
+     for query, response in history:
+         prompt += f"\n\nUser: {query}"
+         prompt += f"\n\nChatGLM-6B: {response}"
+     return prompt
+
+
+ def signal_handler(sig, frame):
+     # Ctrl+C stops the current streaming generation instead of killing the process.
+     global stop_stream
+     stop_stream = True
+
+
+ def main():
+     history = []
+     global stop_stream
+     print(welcome_prompt)
+     while True:
+         query = input("\nUser: ")
+         if query.strip() == "stop":
+             break
+         if query.strip() == "clear":
+             history = []
+             os.system(clear_command)
+             print(welcome_prompt)
+             continue
+         count = 0
+         for response, history in model.stream_chat(tokenizer, query, history=history):
+             if stop_stream:
+                 stop_stream = False
+                 break
+             else:
+                 count += 1
+                 # Redraw the full transcript every 8 steps to simulate streaming output.
+                 if count % 8 == 0:
+                     os.system(clear_command)
+                     print(build_prompt(history), flush=True)
+                     signal.signal(signal.SIGINT, signal_handler)
+         os.system(clear_command)
+         print(build_prompt(history), flush=True)
+
+
+ if __name__ == "__main__":
+     main()
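For a single-turn smoke test without the streaming loop, a minimal sketch, assuming the checkpoint's remote code also provides the non-streaming chat method (the counterpart of the stream_chat call above), which returns the reply together with the updated history:

from transformers import AutoTokenizer, AutoModel

# Same checkpoint and setup as cli_demo.py; requires a CUDA GPU.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

# `history` is a list of (query, response) pairs carried between turns.
response, history = model.chat(tokenizer, "Hello", history=[])
print(response)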
gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ gradio==4.29.0
+ gradio_client
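Because app.py installs dashscope at runtime via os.system, an alternative is to declare it here so the dependency is resolved at build time; a sketch of the amended file (version unpinned, since app.py does not pin one):

gradio==4.29.0
gradio_client
dashscope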