Create demo.py
demo.py
ADDED
@@ -0,0 +1,242 @@
"""A simple web interactive chat demo based on gradio."""

from argparse import ArgumentParser
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

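# Stopping criterion keyed on raw token ids: generation halts as soon as the
# last emitted token id matches one of the special chat-markup tokens.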
class StopOnTokens(StoppingCriteria):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Ids for "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>".
        stop_ids = [2, 6, 7, 8]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

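# Alternative stopping criterion: compares the decoded text of the last
# generated token against a list of stop strings. Relies on the module-level
# `tokenizer` created in the __main__ block.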
class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Keep the stop token ids on CPU: they are only decoded for string
        # comparison, so pinning them to "cuda" would break --cpu-only runs.
        self.stops = list(stops) if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        last_token = input_ids[0][-1]
        for stop in self.stops:
            if tokenizer.decode(stop) == tokenizer.decode(last_token):
                return True
        return False

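# Convert model output to HTML for the chatbot pane: fenced code blocks become
# <pre><code> sections, and markdown-significant characters inside them are
# escaped as HTML entities so they render literally.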
def parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                # Opening fence: start a highlighted code block.
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                # Closing fence.
                lines[i] = "<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    # Inside a code block: escape characters that markdown or
                    # HTML would otherwise interpret.
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text

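# Stream a chat completion: rebuild the message list from the Gradio history,
# run model.generate on a background thread, and yield the partial history as
# tokens arrive from the TextIteratorStreamer.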
def predict(history, max_length, top_p, temperature):
    stop = StopOnTokens()  # unused here; StoppingCriteriaSub below does the stopping
    # messages = [{"role": "system", "content": "You are a helpful assistant"}]
    messages = [{"role": "system", "content": ""}]
    # messages = []
    for idx, (user_msg, model_msg) in enumerate(history):
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    print("\n\n====conversation====\n", messages)
    model_inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
    )

    # stop_words = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"]
    stop_words = ["</s>"]
    stop_words_ids = [
        tokenizer(stop_word, return_tensors="pt", add_special_tokens=False)[
            "input_ids"
        ].squeeze()
        for stop_word in stop_words
    ]
    stopping_criteria = StoppingCriteriaList(
        [StoppingCriteriaSub(stops=stop_words_ids)]
    )

    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": stopping_criteria,
        "repetition_penalty": 1.1,
    }
    # Run generation on a worker thread so this generator can stream tokens.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    for new_token in streamer:
        if new_token != "":
            history[-1][1] += new_token
            yield history

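# Build the Gradio Blocks UI: a chatbot pane, an input box with submit/clear
# buttons, and sliders for maximum length, top-p, and temperature.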
def main(args):
    with gr.Blocks() as demo:
        # gr.Markdown(
        #     """\
        # <p align="center"><img src="https://raw.githubusercontent.com/01-ai/Yi/main/assets/img/Yi_logo_icon_light.svg" style="height: 80px"/><p>"""
        # )
        # gr.Markdown("""<center><font size=8>Yi-Chat Bot</center>""")
        gr.Markdown("""<center><font size=8>🦣MAmmoTH2</center>""")
        # gr.Markdown(
        #     """\
        # <center><font size=3>This WebUI is based on Yi-Chat, developed by 01-AI.</center>"""
        # )
        gr.Markdown(
            """\
<center><font size=4>
MAmmoTH2-8x7B-Plus <a style="text-decoration: none" href="https://huggingface.co/TIGER-Lab/MAmmoTH2-8x7B-Plus/">🤗</a></center>"""
            # <a style="text-decoration: none" href="https://www.modelscope.cn/models/01ai/Yi-34B-Chat/summary">🤖</a>&nbsp;
            # &nbsp;<a style="text-decoration: none" href="https://github.com/01-ai/Yi">Yi GitHub</a></center>
        )

        chatbot = gr.Chatbot()

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False,
                        placeholder="Input...",
                        lines=10,
                        container=False,
                    )
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("🚀 Submit")
            with gr.Column(scale=1):
                emptyBtn = gr.Button("🧹 Clear History")
                max_length = gr.Slider(
                    0,
                    32768,
                    value=4096,
                    step=1.0,
                    label="Maximum length",
                    interactive=True,
                )
                top_p = gr.Slider(
                    0, 1, value=1.0, step=0.01, label="Top P", interactive=True
                )
                temperature = gr.Slider(
                    0.01, 1, value=0.7, step=0.01, label="Temperature", interactive=True
                )

        # Append the new user turn to the history and clear the textbox.
        def user(query, history):
            # return "", history + [[parse_text(query), ""]]
            return "", history + [[query, ""]]

        # Wire the submit button and the textbox Enter key to the same
        # user -> predict pipeline; the clear button resets the chatbot.
        submitBtn.click(
            user, [user_input, chatbot], [user_input, chatbot], queue=False
        ).then(predict, [chatbot, max_length, top_p, temperature], chatbot)
        user_input.submit(
            user, [user_input, chatbot], [user_input, chatbot], queue=False
        ).then(predict, [chatbot, max_length, top_p, temperature], chatbot)
        emptyBtn.click(lambda: None, None, chatbot, queue=False)

    demo.queue()

    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        inbrowser=args.inbrowser,
        share=args.share,
    )

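# Entry point: parse CLI flags, load the tokenizer and model at module level
# (so predict() and StoppingCriteriaSub can see them), then start the web UI.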
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default="TIGER-Lab/MAmmoTH2-8B-Plus",
        help="Checkpoint name or path, default to %(default)r",
    )
    parser.add_argument(
        "--cpu-only", action="store_true", help="Run demo with CPU only"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        default=False,
        help="Create a publicly shareable link for the interface.",
    )
    parser.add_argument(
        "--inbrowser",
        action="store_true",
        default=True,  # defaults to on, so the flag is effectively always set
        help="Automatically launch the interface in a new tab on the default browser.",
    )
    parser.add_argument(
        "--server-port", type=int, default=8110, help="Demo server port."
    )
    parser.add_argument(
        "--server-name", type=str, default="127.0.0.1", help="Demo server name."
    )

    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )

    # Let accelerate place layers across available GPUs unless CPU is forced.
    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        torch_dtype="auto",
        trust_remote_code=True,
    ).eval()

    main(args)
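# Example launch, assuming the default checkpoint and port (adjust as needed):
#   python demo.py --checkpoint-path TIGER-Lab/MAmmoTH2-8B-Plus --server-port 8110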