yentinglin committed on
Commit
ab13bd6
·
1 Parent(s): be6cc22

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import gradio as gr
from text_generation import Client
# text-generation 0.6.0
4
+
5
# End-of-sequence marker appended after every assistant turn in the prompt.
eos_token = "</s>"


def _concat_messages(messages):
    """Render a chat history into the tagged prompt format used by the model.

    Every message is emitted as a ``<|role|>`` tag line, followed by the
    message content with surrounding whitespace stripped, followed by a
    newline. Assistant turns additionally get ``eos_token`` appended directly
    after the content, before the newline.

    Args:
        messages: Iterable of dicts with a ``"role"`` key (``"system"``,
            ``"user"`` or ``"assistant"``) and a ``"content"`` string.

    Returns:
        The concatenated prompt text; an empty string when there are no
        messages.

    Raises:
        ValueError: If a message carries any other role.
    """
    pieces = []
    for message in messages:
        role = message["role"]
        # Validate the role before touching the content so an unknown role is
        # always reported as ValueError, whatever the content looks like.
        if role not in ("system", "user", "assistant"):
            raise ValueError("Invalid role: {}".format(role))
        pieces.append("<|" + role + "|>\n" + message["content"].strip())
        if role == "assistant":
            # Only completed assistant turns are terminated with the EOS marker.
            pieces.append(eos_token)
        pieces.append("\n")
    # Join once instead of growing a string with += on every iteration.
    return "".join(pieces)
19
+
20
# Base URL of the text-generation server. Setting the TGI_ENDPOINT_URL
# environment variable points the demo at another deployment without editing
# this file; when it is unset the original hard-coded host is used.
endpoint_url = os.environ.get(
    "TGI_ENDPOINT_URL",
    "http://ec2-52-193-118-191.ap-northeast-1.compute.amazonaws.com:8080",
)
# Shared client used by every request; timeout=120 is passed through to the
# text_generation Client unchanged.
client = Client(endpoint_url, timeout=120)
22
+
23
# System prompt sent at the start of every conversation (runtime text, kept verbatim).
_SYSTEM_PROMPT = "你是一個由國立台灣大學的NLP實驗室開發的大型語言模型。你基於Transformer架構被訓練,並已經經過大量的台灣中文語料庫的訓練。你的設計目標是理解和生成優雅的繁體中文,並具有跨語境和跨領域的對話能力。使用者可以向你提問任何問題或提出任何話題,並期待從你那裡得到高質量的回答。你應該要盡量幫助使用者解決問題,提供他們需要的資訊,並在適當時候給予建議。"


def _positive_or_none(value):
    """Return *value* when it is greater than zero, otherwise ``None``."""
    if value is None or value <= 0:
        return None
    return value


def generate_response(user_input, max_new_token=1000, top_p=None, temperature=None,
                      top_k=None, do_sample=False, repetition_penalty=None):
    """Send one user message to the model and return a Chatbot history.

    The prompt is the fixed system prompt, the user's message, and an open
    ``<|assistant|>`` tag for the model to complete.

    Every generation parameter has a default, so the function can be wired to
    a single textbox input. With only ``user_input`` supplied, the request is
    the same one the previous hard-coded body sent (``max_new_tokens=1000``,
    no sampling options).

    Args:
        user_input: Text typed by the user.
        max_new_token: Upper bound on generated tokens; values below 1 are
            raised to 1.
        top_p: Nucleus-sampling mass. Only forwarded when strictly between
            0 and 1.
        temperature: Sampling temperature. Only forwarded when positive.
        top_k: Top-k cutoff. Only forwarded when positive.
        do_sample: When false, ``top_p``, ``temperature`` and ``top_k`` are
            not forwarded at all.
        repetition_penalty: Only forwarded when positive.

    Returns:
        A one-element list ``[(user_input, reply)]`` in the
        ``(user message, bot message)`` pair format of ``gr.Chatbot``.
    """
    prompt = _concat_messages([
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": user_input},
    ])
    prompt += "<|assistant|>\n"

    stop_sequences = ["<|assistant|>", eos_token, "<|system|>", "<|user|>"]

    if do_sample:
        # NOTE(review): the text-generation client is believed to reject
        # top_p outside (0, 1) and non-positive temperature/top_k, which the
        # UI sliders can produce at their extremes - confirm against the
        # pinned client version. Out-of-range values are dropped rather than
        # letting the whole request fail.
        in_range_top_p = top_p if top_p is not None and 0 < top_p < 1 else None
        temperature = _positive_or_none(temperature)
        top_k = _positive_or_none(top_k)
        if top_k is not None:
            top_k = int(top_k)  # sliders may deliver floats
    else:
        # NOTE(review): the server presumably samples whenever any of these
        # is set, so they are withheld to keep "Do Sample" off meaning off.
        in_range_top_p = temperature = top_k = None

    res = client.generate(
        prompt,
        stop_sequences=stop_sequences,
        max_new_tokens=max(1, int(max_new_token)),
        do_sample=bool(do_sample),
        top_p=in_range_top_p,
        temperature=temperature,
        top_k=top_k,
        repetition_penalty=_positive_or_none(repetition_penalty),
    )

    reply = res.generated_text
    # NOTE(review): the server may leave the matched stop sequence at the end
    # of the text; trim it so role tags never show up in the chat window.
    for marker in stop_sequences:
        if reply.endswith(marker):
            reply = reply[:-len(marker)]
            break

    # Chatbot pairs are (user message, bot message); the previous code put the
    # literal string "assistant" in the user slot.
    return [(user_input, reply)]
32
+
33
# Gradio front end: a chat window, a multi-line input with a submit button,
# and a side column holding the generation controls.
# NOTE(review): the nesting below is reconstructed from the statement order;
# confirm the layout against the rendered app.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(
                    show_label=False,
                    placeholder="Shift + Enter发送消息...",
                    lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_new_token = gr.Slider(
                0,
                4096,
                value=512,
                step=1.0,
                label="Maximum New Token Length",
                interactive=True)
            top_p = gr.Slider(0, 1, value=0.9, step=0.01,
                              label="Top P", interactive=True)
            temperature = gr.Slider(
                0,
                1,
                value=0.5,
                step=0.01,
                label="Temperature",
                interactive=True)
            top_k = gr.Slider(1, 40, value=40, step=1,
                              label="Top K", interactive=True)
            do_sample = gr.Checkbox(
                value=True,
                label="Do Sample",
                info="use random sample strategy",
                interactive=True)
            repetition_penalty = gr.Slider(
                1.0,
                3.0,
                value=1.1,
                step=0.1,
                label="Repetition Penalty",
                interactive=True)

    # Inputs in the exact positional order of generate_response's parameters,
    # so the controls on screen are actually passed to the handler.
    generation_inputs = [
        user_input,
        max_new_token,
        top_p,
        temperature,
        top_k,
        do_sample,
        repetition_penalty]

    # The button and Shift+Enter run the same pipeline: generate the reply,
    # then empty the textbox. The clear step is chained so it runs after
    # generation for both triggers, instead of being a second, independent
    # click handler on the button only.
    for register in (submitBtn.click, user_input.submit):
        register(
            generate_response,
            generation_inputs,
            [chatbot],
            queue=False).then(
            lambda: "",
            None,
            [user_input],
            queue=False)

    # Returning an empty history clears the chat window.
    emptyBtn.click(lambda: [], outputs=[chatbot], show_progress=True)

demo.launch(share=True)