aaabiao committed on
Commit
96e2a3f
1 Parent(s): 72b7029

Create demo.py

Browse files
Files changed (1) hide show
  1. demo.py +242 -0
demo.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A simple web interactive chat demo based on gradio."""
2
+
3
+ from argparse import ArgumentParser
4
+ from threading import Thread
5
+
6
+ import gradio as gr
7
+ import torch
8
+ from transformers import (
9
+ AutoModelForCausalLM,
10
+ AutoTokenizer,
11
+ StoppingCriteria,
12
+ StoppingCriteriaList,
13
+ TextIteratorStreamer,
14
+ )
15
+
16
+
17
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation on a fixed set of token ids.

    The ids are hard-coded; the original source annotates them as
    "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>", so they are
    only meaningful for a tokenizer that uses that id mapping.
    """

    # A flat tuple of ids.  The previous value ``([2, 6, 7, 8],)`` was a
    # 1-tuple holding a list, so the loop compared the last token against the
    # whole list instead of against each individual id.
    STOP_IDS = (2, 6, 7, 8)

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """Return True when the last token of the first sequence is a stop id.

        Args:
            input_ids: token ids generated so far; only ``input_ids[0]`` is
                inspected.
            scores: unused, accepted to satisfy the criterion call signature.
            **kwargs: ignored.
        """
        if input_ids.shape[-1] == 0:
            # Nothing generated yet, so there is no last token to inspect.
            return False
        return int(input_ids[0][-1]) in self.STOP_IDS
28
+
29
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation when the last token decodes to one of the stop words.

    The comparison is done on decoded text, using the module-level
    ``tokenizer``, which must therefore exist before the criterion is called.
    """

    def __init__(self, stops=None, encounters=1):
        """Store the stop sequences.

        Args:
            stops: iterable of token-id tensors, one per stop word.  ``None``
                (the default) means no stop words.
            encounters: accepted for backward compatibility; not used.
        """
        super().__init__()
        # The stops are only ever passed to ``tokenizer.decode``, so they do
        # not need to live on the GPU.  The former hard-coded ``.to("cuda")``
        # conflicted with the script's ``--cpu-only`` mode.
        self.stops = list(stops) if stops is not None else []

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """Return True if the last token of the first sequence is a stop word.

        Args:
            input_ids: token ids generated so far; only ``input_ids[0][-1]``
                is inspected.
            scores: unused.
            **kwargs: ignored; accepted so extra arguments forwarded by the
                caller do not raise.
        """
        if not self.stops:
            return False
        # Decode the last token once instead of once per stop word.
        last_token_text = tokenizer.decode(input_ids[0][-1])
        return any(tokenizer.decode(stop) == last_token_text for stop in self.stops)
40
+
41
+
42
# Per-character escapes applied to lines inside a fenced code block.  None of
# the replacement strings contains a character that is itself a key, so one
# translate() pass gives the same result as replacing them one after another.
_CODE_LINE_ESCAPES = str.maketrans(
    {
        "`": "\\`",
        "<": "&lt;",
        ">": "&gt;",
        " ": "&nbsp;",
        "*": "&ast;",
        "_": "&lowbar;",
        "-": "&#45;",
        ".": "&#46;",
        "!": "&#33;",
        "(": "&#40;",
        ")": "&#41;",
        "$": "&#36;",
    }
)


def parse_text(text):
    """Convert chat text with ``` fences into an HTML fragment.

    Empty lines are dropped.  A line containing ``` toggles code mode: an
    opening fence becomes ``<pre><code class="language-X">`` where X is the
    text after the last backtick, and a closing fence becomes
    ``<br></code></pre>``.  Every other line except the very first is prefixed
    with ``<br>``; lines inside a code block are additionally escaped
    character by character.

    Args:
        text: raw message text.

    Returns:
        The converted lines concatenated into one string.
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                # Odd fence: opens a block; the language tag follows the fence.
                language = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if fence_count % 2 == 1:
                line = line.translate(_CODE_LINE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)
72
+
73
+
74
def predict(history, max_length, top_p, temperature):
    """Stream the model's reply to the last turn of ``history``.

    This is a generator: it appends each decoded text chunk to
    ``history[-1][1]`` and yields the updated history so the UI can refresh
    incrementally.  It uses the module-level ``tokenizer`` and ``model``.

    Args:
        history: list of ``[user_message, model_message]`` pairs; the last
            pair is the turn being answered.
        max_length: maximum number of new tokens to generate.
        top_p: nucleus-sampling probability mass.
        temperature: sampling temperature.

    Yields:
        The same ``history`` list after each non-empty chunk is appended.
    """
    if not history:
        # Nothing to answer; also avoids indexing ``history[-1]`` below.
        yield history
        return

    messages = [{"role": "system", "content": ""}]
    last_idx = len(history) - 1
    for idx, (user_msg, model_msg) in enumerate(history):
        if idx == last_idx and not model_msg:
            # The pending turn: only the user side exists so far.
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    print("\n\n====conversation====\n", messages)
    model_inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
    )

    stop_words = ["</s>"]
    stop_words_ids = [
        tokenizer(stop_word, return_tensors="pt", add_special_tokens=False)[
            "input_ids"
        ].squeeze()
        for stop_word in stop_words
    ]
    stopping_criteria = StoppingCriteriaList(
        [StoppingCriteriaSub(stops=stop_words_ids)]
    )

    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        # The value comes from a UI slider: force an integer and at least one
        # token so a zero or fractional setting cannot reach generate().
        "max_new_tokens": max(1, int(max_length)),
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": stopping_criteria,
        "repetition_penalty": 1.1,
    }
    # Run generation in a background thread so this generator can consume
    # the streamer while tokens are still being produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    for new_token in streamer:
        if new_token != "":
            history[-1][1] += new_token
            yield history

    # The streamer is exhausted, so generation is done; reap the thread.
    worker.join()
118
+
119
+
120
def main(args):
    """Build the gradio chat UI and launch the demo server.

    Args:
        args: parsed command-line namespace; ``server_name``, ``server_port``,
            ``inbrowser`` and ``share`` are forwarded to ``demo.launch``.
    """
    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=8>🦣MAmmoTH2</center>""")
        gr.Markdown(
            """\
<center><font size=4>
MAmmoTH2-8x7B-Plus <a style="text-decoration: none" href="https://huggingface.co/TIGER-Lab/MAmmoTH2-8x7B-Plus/">🤗</a> """
        )

        chatbot = gr.Chatbot()

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False,
                        placeholder="Input...",
                        lines=10,
                        container=False,
                    )
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("🚀 Submit")
            with gr.Column(scale=1):
                emptyBtn = gr.Button("🧹 Clear History")
                max_length = gr.Slider(
                    # Minimum is 1: a value of 0 would request zero new tokens.
                    1,
                    32768,
                    value=4096,
                    step=1.0,
                    label="Maximum length",
                    interactive=True,
                )
                top_p = gr.Slider(
                    0, 1, value=1.0, step=0.01, label="Top P", interactive=True
                )
                temperature = gr.Slider(
                    0.01, 1, value=0.7, step=0.01, label="Temperature", interactive=True
                )

        def user(query, history):
            """Clear the input box and append the query as a pending turn."""
            return "", history + [[query, ""]]

        # The button click and pressing Enter in the textbox do the same thing:
        # record the user turn immediately, then stream the reply.
        for register in (submitBtn.click, user_input.submit):
            register(
                user, [user_input, chatbot], [user_input, chatbot], queue=False
            ).then(predict, [chatbot, max_length, top_p, temperature], chatbot)
        emptyBtn.click(lambda: None, None, chatbot, queue=False)

    demo.queue()

    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        inbrowser=args.inbrowser,
        share=args.share,
    )
191
+
192
+
193
if __name__ == "__main__":
    # Command-line entry point: parse options, load tokenizer and model into
    # the module-level names used by predict(), then start the UI.
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default="TIGER-Lab/MAmmoTH2-8B-Plus",
        help="Checkpoint name or path, default to %(default)r",
    )
    parser.add_argument(
        "--cpu-only", action="store_true", help="Run demo with CPU only"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        default=False,
        help="Create a publicly shareable link for the interface.",
    )
    parser.add_argument(
        "--inbrowser",
        action="store_true",
        default=True,
        help="Automatically launch the interface in a new tab on the default browser.",
    )
    # ``--inbrowser`` is store_true with default=True, so on its own it can
    # never be turned off; this companion flag provides the opt-out.
    parser.add_argument(
        "--no-inbrowser",
        dest="inbrowser",
        action="store_false",
        help="Do not open the interface in a browser tab.",
    )
    parser.add_argument(
        "--server-port", type=int, default=8110, help="Demo server port."
    )
    parser.add_argument(
        "--server-name", type=str, default="127.0.0.1", help="Demo server name."
    )

    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )

    device_map = "cpu" if args.cpu_only else "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        torch_dtype="auto",
        trust_remote_code=True,
    ).eval()

    main(args)