Spaces: Runtime error

wangrongsheng committed on
Commit • b62eec7
1 Parent(s): 2ff72d7

Upload 4 files

Browse files
- app.py +235 -293
- model.py +74 -0
- requirements.txt +8 -9
- style.css +16 -0
app.py
CHANGED
@@ -1,329 +1,271 @@
-# pylint: disable=broad-exception-caught, redefined-outer-name, missing-function-docstring, missing-module-docstring, too-many-arguments, line-too-long, invalid-name, redefined-builtin, redefined-argument-from-local
-# import gradio as gr
-import torch
-from loguru import logger
-from transformers import AutoModel, AutoTokenizer
-
-try:
-    time.tzset()  # type: ignore # pylint: disable=no-member
-except Exception:
-    # Windows
-    logger.warning("Windows, cant run time.tzset()")
-
-# model_name = "OpenMEDLab/PULSE-7bv5"
-
-model = model.eval()
-
-def postprocess(self, y):
-    if y is None:
-        return []
-    for i, (message, response) in enumerate(y):
-        y[i] = (
-            None if message is None else mdtex2html.convert(message),
-            None if response is None else mdtex2html.convert(response),
-        )
-    return y
-
-gr.Chatbot.postprocess = postprocess
-
-def parse_text(text):
-    lines = text.split("\n")
-    lines = [line for line in lines if line != ""]
-    count = 0
-    for i, line in enumerate(lines):
-        if "```" in line:
-            count += 1
-            items = line.split("`")
-            if count % 2 == 1:
-                lines[i] = f'<pre><code class="language-{items[-1]}">'
-            else:
-                lines[i] = "<br></code></pre>"
-        else:
-            if i > 0:
-                if count % 2 == 1:
-                    line = line.replace("`", r"\`")
-                    line = line.replace("<", "&lt;")
-                    line = line.replace(">", "&gt;")
-                    line = line.replace(" ", "&nbsp;")
-                    line = line.replace("*", "&ast;")
-                    line = line.replace("_", "&lowbar;")
-                    line = line.replace("-", "&#45;")
-                    line = line.replace(".", "&#46;")
-                    line = line.replace("!", "&#33;")
-                    line = line.replace("(", "&#40;")
-                    line = line.replace(")", "&#41;")
-                    line = line.replace("$", "&#36;")
-                lines[i] = "<br>" + line
-    text = "".join(lines)
-    return text
-
-def predict(
-    RETRY_FLAG, input, chatbot, max_length, top_p, temperature, history, past_key_values
-):
-    try:
-        chatbot.append((parse_text(input), ""))
-    except Exception as exc:
-        logger.error(exc)
-        logger.debug(f"{chatbot=}")
-        _ = """
-        if chatbot:
-            chatbot[-1] = (parse_text(input), str(exc))
-            yield chatbot, history, past_key_values
-        # """
-        yield chatbot, history, past_key_values
-    """
-    for response, history, past_key_values in model.stream_chat(
-        tokenizer,
-        input,
-        history,
-        past_key_values=past_key_values,
-        return_past_key_values=True,
-        max_length=max_length,
-        top_p=top_p,
-        temperature=temperature,
-    ):
-    """
-    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
-                                               temperature=temperature):
-        chatbot[-1] = (parse_text(input), parse_text(response))
-
-        yield chatbot, history, past_key_values
-
-
-def trans_api(input, max_length=40960, top_p=0.7, temperature=0.95):
-    if max_length < 10:
-        max_length = 40960
-    if top_p < 0.1 or top_p > 1:
-        top_p = 0.7
-    if temperature <= 0 or temperature > 1:
-        temperature = 0.01
-    try:
-        res, _ = model.chat(
-            tokenizer,
-            input,
-            history=[],
-            past_key_values=None,
-            max_length=max_length,
-            top_p=top_p,
-            temperature=temperature,
-        )
-        # logger.debug(f"{res=} \n{_=}")
-    except Exception as exc:
-        logger.error(f"{exc=}")
-        res = str(exc)
-
-    return res
-
-
-def reset_user_input():
-    return gr.update(value="")
-
-
-def reset_state():
-    return [], [], None
-
-
-# Delete last turn
-def delete_last_turn(chat, history):
-    if chat and history:
-        chat.pop(-1)
-        history.pop(-1)
-    return chat, history
-
-
-# Regenerate response
-def retry_last_answer(
-    user_input, chatbot, max_length, top_p, temperature, history, past_key_values
-):
-    if chatbot and history:
-        # Removing the previous conversation from chat
-        chatbot.pop(-1)
-        # Setting up a flag to capture a retry
-        RETRY_FLAG = True
-        # Getting last message from user
-        user_input = history[-1][0]
-        # Removing bot response from the history
-        history.pop(-1)
-
-    yield from predict(
-        RETRY_FLAG,  # type: ignore
-        user_input,
-        chatbot,
-        max_length,
-        top_p,
-        temperature,
-        history,
-        past_key_values,
-    )
-
-    )
-
-    with gr.Accordion("🎈 Info", open=False):
-        _ = f"""
-        ## 欢迎体验IvyGPT
-
-            api_name="predict",
-    )
-    submitBtn.click(reset_user_input, [], [user_input])
-
-    emptyBtn.click(
-        reset_state, outputs=[chatbot, history, past_key_values], show_progress="full"
-    )
-
-    deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])
-
-    with gr.Accordion("Example inputs", open=True):
-        examples = gr.Examples(
-            examples=[
-                ["熬夜对身体有什么危害? "],
-                ["新冠肺炎怎么预防"],
-                ["系统性红斑狼疮的危害和治疗方法是什么?"],
-            ],
-            inputs=[user_input],
-            examples_per_page=50,
-        )
-
-    input_text.submit(
-        trans_api,
-        [input_text, max_length, top_p, temperature],
-        out_text,
-        show_progress="full",
-        api_name="tr1",
-    )
-    # """
-
-# reduce to 5 if OOM occurs to often
-demo.queue(

+from typing import Iterator
+
+import gradio as gr
+import torch
+
+from model import get_input_token_length, run
+
+DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
+"""
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = 4000
+
+DESCRIPTION = """
+# CareLlama-关怀羊驼
+
+- CareLlama (关怀羊驼)是一个医疗大语言模型,同时它集合了数十个公开可用的医疗微调数据集和开放可用的医疗大语言模型以促进医疗LLM快速发展。
+- Medical LLM, Open Source Driven for a Healthy Future.
+
+"""
+
+LICENSE = """
+<p/>
+
+---
+本项目相关资源仅供学术研究之用,严禁用于商业用途。使用涉及第三方代码的部分时,请严格遵循相应的开源协议。模型生成的内容受模型计算、随机性和量化精度损失等因素影响,本项目无法对其准确性作出保证。即使本项目模型输出符合医学事实,也不能被用作实际医学诊断的依据。对于模型输出的任何内容,本项目不承担任何法律责任,亦不对因使用相关资源和输出结果而可能产生的任何损失承担责任。
+"""
+
+if not torch.cuda.is_available():
+    DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
+
+
+def clear_and_save_textbox(message: str) -> tuple[str, str]:
+    return '', message
+
+
+def display_input(message: str,
+                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
+    history.append((message, ''))
+    return history
+
+
+def delete_prev_fn(
+        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
+    try:
+        message, _ = history.pop()
+    except IndexError:
+        message = ''
+    return history, message or ''
+
+
+def generate(
+    message: str,
+    history_with_input: list[tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    top_k: int,
+) -> Iterator[list[tuple[str, str]]]:
+    if max_new_tokens > MAX_MAX_NEW_TOKENS:
+        raise ValueError
+
+    history = history_with_input[:-1]
+    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
+    try:
+        first_response = next(generator)
+        yield history + [(message, first_response)]
+    except StopIteration:
+        yield history + [(message, '')]
+    for response in generator:
+        yield history + [(message, response)]
+
+
+def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
+    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
+    for x in generator:
+        pass
+    return '', x
+
+
+def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
+    input_token_length = get_input_token_length(message, chat_history, system_prompt)
+    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
+        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
+
+
+with gr.Blocks(css='style.css') as demo:
+    gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(value='Duplicate Space for private use',
+                       elem_id='duplicate-button')
+
+    with gr.Group():
+        chatbot = gr.Chatbot(label='CareLlama')
+        with gr.Row():
+            textbox = gr.Textbox(
+                container=False,
+                show_label=False,
+                placeholder='请输入内容...',
+                scale=10,
+            )
+            submit_button = gr.Button('Submit',
+                                      variant='primary',
+                                      scale=1,
+                                      min_width=0)
+    with gr.Row():
+        retry_button = gr.Button('🔄 重试', variant='secondary')
+        undo_button = gr.Button('↩️ 撤销', variant='secondary')
+        clear_button = gr.Button('🗑️ 清除', variant='secondary')
+
+    saved_input = gr.State()
+
+    with gr.Accordion(label='Advanced options', open=False):
+        system_prompt = gr.Textbox(label='System prompt',
+                                   value=DEFAULT_SYSTEM_PROMPT,
+                                   lines=6)
+        max_new_tokens = gr.Slider(
+            label='Max new tokens',
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        )
+        temperature = gr.Slider(
+            label='Temperature',
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=1.0,
+        )
+        top_p = gr.Slider(
+            label='Top-p (nucleus sampling)',
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.95,
+        )
+        top_k = gr.Slider(
+            label='Top-k',
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        )
+
+    gr.Examples(
+        examples=[
+            '你好',
+        ],
+        inputs=textbox,
+        outputs=[textbox, chatbot],
+        fn=process_example,
+        cache_examples=True,
+    )
+
+    gr.Markdown(LICENSE)
+
+    textbox.submit(
+        fn=clear_and_save_textbox,
+        inputs=textbox,
+        outputs=[textbox, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=display_input,
+        inputs=[saved_input, chatbot],
+        outputs=chatbot,
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=check_input_token_length,
+        inputs=[saved_input, chatbot, system_prompt],
+        api_name=False,
+        queue=False,
+    ).success(
+        fn=generate,
+        inputs=[
+            saved_input,
+            chatbot,
+            system_prompt,
+            max_new_tokens,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=chatbot,
+        api_name=False,
+    )
+
+    button_event_preprocess = submit_button.click(
+        fn=clear_and_save_textbox,
+        inputs=textbox,
+        outputs=[textbox, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=display_input,
+        inputs=[saved_input, chatbot],
+        outputs=chatbot,
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=check_input_token_length,
+        inputs=[saved_input, chatbot, system_prompt],
+        api_name=False,
+        queue=False,
+    ).success(
+        fn=generate,
+        inputs=[
+            saved_input,
+            chatbot,
+            system_prompt,
+            max_new_tokens,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=chatbot,
+        api_name=False,
+    )
+
+    retry_button.click(
+        fn=delete_prev_fn,
+        inputs=chatbot,
+        outputs=[chatbot, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=display_input,
+        inputs=[saved_input, chatbot],
+        outputs=chatbot,
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=generate,
+        inputs=[
+            saved_input,
+            chatbot,
+            system_prompt,
+            max_new_tokens,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=chatbot,
+        api_name=False,
+    )
+
+    undo_button.click(
+        fn=delete_prev_fn,
+        inputs=chatbot,
+        outputs=[chatbot, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=lambda x: x,
+        inputs=[saved_input],
+        outputs=textbox,
+        api_name=False,
+        queue=False,
+    )
+
+    clear_button.click(
+        fn=lambda: ([], ''),
+        outputs=[chatbot, saved_input],
+        queue=False,
+        api_name=False,
+    )
+
+demo.queue(max_size=20).launch()
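
The new app.py wires each user action as a chain of Gradio events: a submit first stashes and clears the textbox, then echoes the input into the chat history, then validates the token count, and only on success streams generate into the chatbot. Below is a minimal, self-contained sketch (not part of the commit) of that same .then()/.success() chaining pattern; the component and function names (save_and_clear, check_not_empty, echo, stash) are illustrative, and a string-reversal stands in for the real model call.

import gradio as gr

def save_and_clear(message: str) -> tuple[str, str]:
    # Stash the raw text in a gr.State and clear the visible textbox.
    return '', message

def check_not_empty(saved: str) -> None:
    # Stand-in for check_input_token_length: raising gr.Error aborts the chain,
    # so the .success(...) step below never runs for empty input.
    if not saved.strip():
        raise gr.Error('Please type a message first.')

def echo(saved: str, history: list) -> list:
    # Stand-in for generate(): just mirrors the input back.
    history = history or []
    history.append((saved, saved[::-1]))
    return history

with gr.Blocks() as sketch:
    chat = gr.Chatbot()
    box = gr.Textbox()
    stash = gr.State()

    box.submit(
        fn=save_and_clear, inputs=box, outputs=[box, stash], queue=False,
    ).then(
        fn=check_not_empty, inputs=stash, queue=False,
    ).success(
        fn=echo, inputs=[stash, chat], outputs=chat,
    )

if __name__ == '__main__':
    sketch.launch()
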
model.py
ADDED
@@ -0,0 +1,74 @@
+from threading import Thread
+from typing import Iterator
+
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+model_id = '../merge'
+
+if torch.cuda.is_available():
+    config = AutoConfig.from_pretrained(model_id)
+    config.pretraining_tp = 1
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        config=config,
+        torch_dtype=torch.float16,
+        load_in_4bit=True,
+        device_map='auto'
+    )
+else:
+    model = None
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+
+def get_prompt(message: str, chat_history: list[tuple[str, str]],
+               system_prompt: str) -> str:
+    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+    # The first user input is _not_ stripped
+    do_strip = False
+    for user_input, response in chat_history:
+        user_input = user_input.strip() if do_strip else user_input
+        do_strip = True
+        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+    message = message.strip() if do_strip else message
+    texts.append(f'{message} [/INST]')
+    return ''.join(texts)
+
+
+def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
+    prompt = get_prompt(message, chat_history, system_prompt)
+    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
+    return input_ids.shape[-1]
+
+
+def run(message: str,
+        chat_history: list[tuple[str, str]],
+        system_prompt: str,
+        max_new_tokens: int = 1024,
+        temperature: float = 0.8,
+        top_p: float = 0.95,
+        top_k: int = 50) -> Iterator[str]:
+    prompt = get_prompt(message, chat_history, system_prompt)
+    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')
+
+    streamer = TextIteratorStreamer(tokenizer,
+                                    timeout=10.,
+                                    skip_prompt=True,
+                                    skip_special_tokens=True)
+    generate_kwargs = dict(
+        inputs,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield ''.join(outputs)
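
For reference, get_prompt assembles the standard Llama 2 chat template: [INST]/[/INST] turns wrapped around an optional <<SYS>> block. The standalone illustration below is not part of the commit and does not need the model or tokenizer; the system prompt, history pair, and message are made-up examples.

def get_prompt_demo() -> str:
    # Same string-assembly logic as get_prompt(), inlined with toy inputs.
    system_prompt = 'You are a helpful medical assistant.'
    chat_history = [('你好', '你好,请问有什么可以帮您?')]
    message = '熬夜对身体有什么危害?'
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)

print(get_prompt_demo())
# <s>[INST] <<SYS>>
# You are a helpful medical assistant.
# <</SYS>>
#
# 你好 [/INST] 你好,请问有什么可以帮您? </s><s>[INST] 熬夜对身体有什么危害? [/INST]
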
requirements.txt
CHANGED
@@ -1,9 +1,8 @@
-loguru
+accelerate==0.21.0
+bitsandbytes==0.40.2
+gradio==3.37.0
+protobuf==3.20.3
+scipy==1.11.1
+sentencepiece==0.1.99
+torch==2.0.1
+transformers==4.31.0
style.css
ADDED
@@ -0,0 +1,16 @@
+h1 {
+  text-align: center;
+}
+
+#duplicate-button {
+  margin: auto;
+  color: white;
+  background: #1565c0;
+  border-radius: 100vh;
+}
+
+#component-0 {
+  max-width: 900px;
+  margin: auto;
+  padding-top: 1.5rem;
+}