Spaces:
Running
on
Zero
Running
on
Zero
testing
Browse files- app.py +546 -77
- serve/builder.py +113 -0
- serve/constants.py +7 -0
- serve/controller.py +277 -0
- serve/conversation.py +234 -0
- serve/examples/example_1.png +0 -0
- serve/examples/example_2.png +0 -0
- serve/examples/icon.jpg +0 -0
- serve/examples/user.png +0 -0
- serve/gradio_web_server.py +496 -0
- serve/mm_utils.py +107 -0
- serve/model_worker.py +290 -0
- serve/utils.py +124 -0
app.py
CHANGED
@@ -1,86 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
import
|
3 |
-
|
4 |
-
|
5 |
-
import
|
6 |
-
import
|
7 |
-
from PIL import Image
|
8 |
-
import torch
|
9 |
import spaces
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
import subprocess
|
11 |
-
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
12 |
|
13 |
-
|
|
|
|
|
14 |
|
15 |
-
|
16 |
-
'qnguyen3/nanoLLaVA',
|
17 |
-
trust_remote_code=True)
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
device_map='auto',
|
23 |
-
trust_remote_code=True)
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
@spaces.GPU
|
27 |
-
def
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import time
|
6 |
import gradio as gr
|
7 |
+
import requests
|
8 |
+
import hashlib
|
9 |
+
import pypandoc
|
10 |
+
import base64
|
11 |
+
import sys
|
|
|
|
|
12 |
import spaces
|
13 |
+
|
14 |
+
from io import BytesIO
|
15 |
+
|
16 |
+
from serve.conversation import (default_conversation, conv_templates, SeparatorStyle)
|
17 |
+
from serve.constants import LOGDIR
|
18 |
+
from serve.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg)
|
19 |
import subprocess
|
|
|
20 |
|
21 |
+
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
|
22 |
+
|
23 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
24 |
|
25 |
+
headers = {"User-Agent": "Bunny Client"}
|
|
|
|
|
26 |
|
27 |
+
no_change_btn = gr.update()
|
28 |
+
enable_btn = gr.update(interactive=True)
|
29 |
+
disable_btn = gr.update(interactive=False)
|
|
|
|
|
30 |
|
31 |
+
priority = {
|
32 |
+
"Bunny": "aaaaaaa",
|
33 |
+
}
|
34 |
+
|
35 |
+
def start_controller():
|
36 |
+
print("Starting the controller")
|
37 |
+
controller_command = [
|
38 |
+
sys.executable,
|
39 |
+
"serve/controller.py",
|
40 |
+
"--host",
|
41 |
+
"0.0.0.0",
|
42 |
+
"--port",
|
43 |
+
"10000",
|
44 |
+
]
|
45 |
+
print(controller_command)
|
46 |
+
return subprocess.Popen(controller_command)
|
47 |
|
48 |
@spaces.GPU
|
49 |
+
def start_worker(model_path: str):
|
50 |
+
print(f"Starting the model worker for the model {model_path}")
|
51 |
+
model_path = 'qnguyen3/nanoLLaVA'
|
52 |
+
worker_command = [
|
53 |
+
sys.executable,
|
54 |
+
"serve/model_worker.py",
|
55 |
+
"--host",
|
56 |
+
"0.0.0.0",
|
57 |
+
"--controller",
|
58 |
+
"http://localhost:10000",
|
59 |
+
"--port",
|
60 |
+
"40000",
|
61 |
+
"worker",
|
62 |
+
"http://localhost:40000":
|
63 |
+
"--model-path",
|
64 |
+
model_path,
|
65 |
+
"--model-type",
|
66 |
+
"qwen1.5-0.5b",
|
67 |
+
"--use-flash-attn",
|
68 |
+
]
|
69 |
+
print(worker_command)
|
70 |
+
return subprocess.Popen(worker_command)
|
71 |
+
|
72 |
+
|
73 |
+
def get_conv_log_filename():
|
74 |
+
t = datetime.datetime.now()
|
75 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
76 |
+
return name
|
77 |
+
|
78 |
+
|
79 |
+
def get_model_list():
|
80 |
+
ret = requests.post(args.controller_url + "/refresh_all_workers")
|
81 |
+
assert ret.status_code == 200
|
82 |
+
ret = requests.post(args.controller_url + "/list_models")
|
83 |
+
models = ret.json()["models"]
|
84 |
+
models.sort(key=lambda x: priority.get(x, x))
|
85 |
+
logger.info(f"Models: {models}")
|
86 |
+
return models
|
87 |
+
|
88 |
+
|
89 |
+
get_window_url_params = """
|
90 |
+
function() {
|
91 |
+
const params = new URLSearchParams(window.location.search);
|
92 |
+
url_params = Object.fromEntries(params);
|
93 |
+
console.log(url_params);
|
94 |
+
return url_params;
|
95 |
+
}
|
96 |
+
"""
|
97 |
+
|
98 |
+
|
99 |
+
def load_demo(url_params, request: gr.Request):
|
100 |
+
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
101 |
+
|
102 |
+
dropdown_update = gr.update(visible=True)
|
103 |
+
if "model" in url_params:
|
104 |
+
model = url_params["model"]
|
105 |
+
if model in models:
|
106 |
+
dropdown_update = gr.update(
|
107 |
+
value=model, visible=True)
|
108 |
+
|
109 |
+
state = default_conversation.copy()
|
110 |
+
return state, dropdown_update
|
111 |
+
|
112 |
+
|
113 |
+
def load_demo_refresh_model_list(request: gr.Request):
|
114 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
115 |
+
models = get_model_list()
|
116 |
+
state = default_conversation.copy()
|
117 |
+
dropdown_update = gr.update(
|
118 |
+
choices=models,
|
119 |
+
value=models[0] if len(models) > 0 else ""
|
120 |
+
)
|
121 |
+
return state, dropdown_update
|
122 |
+
|
123 |
+
|
124 |
+
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
125 |
+
with open(get_conv_log_filename(), "a") as fout:
|
126 |
+
data = {
|
127 |
+
"tstamp": round(time.time(), 4),
|
128 |
+
"type": vote_type,
|
129 |
+
"model": model_selector,
|
130 |
+
"state": state.dict(),
|
131 |
+
"ip": request.client.host,
|
132 |
+
}
|
133 |
+
fout.write(json.dumps(data) + "\n")
|
134 |
+
|
135 |
+
|
136 |
+
def upvote_last_response(state, model_selector, request: gr.Request):
|
137 |
+
logger.info(f"upvote. ip: {request.client.host}")
|
138 |
+
vote_last_response(state, "upvote", model_selector, request)
|
139 |
+
return ("",) + (disable_btn,) * 3
|
140 |
+
|
141 |
+
|
142 |
+
def downvote_last_response(state, model_selector, request: gr.Request):
|
143 |
+
logger.info(f"downvote. ip: {request.client.host}")
|
144 |
+
vote_last_response(state, "downvote", model_selector, request)
|
145 |
+
return ("",) + (disable_btn,) * 3
|
146 |
+
|
147 |
+
|
148 |
+
def flag_last_response(state, model_selector, request: gr.Request):
|
149 |
+
logger.info(f"flag. ip: {request.client.host}")
|
150 |
+
vote_last_response(state, "flag", model_selector, request)
|
151 |
+
return ("",) + (disable_btn,) * 3
|
152 |
+
|
153 |
+
|
154 |
+
def regenerate(state, image_process_mode, request: gr.Request):
|
155 |
+
logger.info(f"regenerate. ip: {request.client.host}")
|
156 |
+
state.messages[-1][-1] = None
|
157 |
+
prev_human_msg = state.messages[-2]
|
158 |
+
if type(prev_human_msg[1]) in (tuple, list):
|
159 |
+
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
|
160 |
+
state.skip_next = False
|
161 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
162 |
+
|
163 |
+
|
164 |
+
def clear_history(request: gr.Request):
|
165 |
+
logger.info(f"clear_history. ip: {request.client.host}")
|
166 |
+
state = default_conversation.copy()
|
167 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
168 |
+
|
169 |
+
|
170 |
+
def save_conversation(conversation):
|
171 |
+
print("save_conversation_wrapper is called")
|
172 |
+
html_content = "<html><body>"
|
173 |
+
|
174 |
+
for role, message in conversation.messages:
|
175 |
+
if isinstance(message, str): # only text
|
176 |
+
html_content += f"<p><b>{role}</b>: {message}</p>"
|
177 |
+
elif isinstance(message, tuple): # text+image
|
178 |
+
text, image_obj, _ = message
|
179 |
+
|
180 |
+
# add text
|
181 |
+
if text:
|
182 |
+
html_content += f"<p><b>{role}</b>: {text}</p>"
|
183 |
+
|
184 |
+
# add image
|
185 |
+
buffered = BytesIO()
|
186 |
+
image_obj.save(buffered, format="PNG")
|
187 |
+
encoded_image = base64.b64encode(buffered.getvalue()).decode()
|
188 |
+
html_content += f'<img src="data:image/png;base64,{encoded_image}" /><br>'
|
189 |
+
|
190 |
+
html_content += "</body></html>"
|
191 |
+
|
192 |
+
doc_path = "./conversation.docx"
|
193 |
+
pypandoc.convert_text(html_content, 'docx', format='html', outputfile=doc_path,
|
194 |
+
extra_args=["-M2GB", "+RTS", "-K64m", "-RTS"])
|
195 |
+
return doc_path
|
196 |
+
|
197 |
+
|
198 |
+
def add_text(state, text, image, image_process_mode, request: gr.Request):
|
199 |
+
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
|
200 |
+
if len(text) <= 0 and image is None:
|
201 |
+
state.skip_next = True
|
202 |
+
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
|
203 |
+
if args.moderate:
|
204 |
+
flagged = violates_moderation(text)
|
205 |
+
if flagged:
|
206 |
+
state.skip_next = True
|
207 |
+
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
|
208 |
+
no_change_btn,) * 5
|
209 |
+
|
210 |
+
text = text[:1536] # Hard cut-off
|
211 |
+
if image is not None:
|
212 |
+
text = text[:1200] # Hard cut-off for images
|
213 |
+
if '<image>' not in text:
|
214 |
+
# text = '<Image><image></Image>' + text
|
215 |
+
text = text + '\n<image>'
|
216 |
+
text = (text, image, image_process_mode)
|
217 |
+
if len(state.get_images(return_pil=True)) > 0:
|
218 |
+
state = default_conversation.copy()
|
219 |
+
logger.info(f"Input Text: {text}")
|
220 |
+
state.append_message(state.roles[0], text)
|
221 |
+
state.append_message(state.roles[1], None)
|
222 |
+
state.skip_next = False
|
223 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
224 |
+
|
225 |
+
|
226 |
+
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
|
227 |
+
logger.info(f"http_bot. ip: {request.client.host}")
|
228 |
+
start_tstamp = time.time()
|
229 |
+
model_name = model_selector
|
230 |
+
|
231 |
+
if state.skip_next:
|
232 |
+
# This generate call is skipped due to invalid inputs
|
233 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
234 |
+
return
|
235 |
+
|
236 |
+
if len(state.messages) == state.offset + 2:
|
237 |
+
template_name = "bunny"
|
238 |
+
new_state = conv_templates[template_name].copy()
|
239 |
+
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
240 |
+
new_state.append_message(new_state.roles[1], None)
|
241 |
+
state = new_state
|
242 |
+
|
243 |
+
logger.info(f"Processed Input Text: {state.messages[-2][1]}")
|
244 |
+
# Query worker address
|
245 |
+
controller_url = args.controller_url
|
246 |
+
ret = requests.post(controller_url + "/get_worker_address",
|
247 |
+
json={"model": model_name})
|
248 |
+
worker_addr = ret.json()["address"]
|
249 |
+
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
250 |
+
|
251 |
+
# No available worker
|
252 |
+
if worker_addr == "":
|
253 |
+
state.messages[-1][-1] = server_error_msg
|
254 |
+
yield (state, state.to_gradio_chatbot(), enable_btn, enable_btn, enable_btn)
|
255 |
+
return
|
256 |
+
|
257 |
+
# Construct prompt
|
258 |
+
prompt = state.get_prompt()
|
259 |
+
|
260 |
+
all_images = state.get_images(return_pil=True)
|
261 |
+
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
|
262 |
+
for image, hash in zip(all_images, all_image_hash):
|
263 |
+
t = datetime.datetime.now()
|
264 |
+
filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
|
265 |
+
if not os.path.isfile(filename):
|
266 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
267 |
+
image.save(filename)
|
268 |
+
|
269 |
+
# Make requests
|
270 |
+
pload = {
|
271 |
+
"model": model_name,
|
272 |
+
"prompt": prompt,
|
273 |
+
"temperature": float(temperature),
|
274 |
+
"top_p": float(top_p),
|
275 |
+
"max_new_tokens": min(int(max_new_tokens), 1536),
|
276 |
+
"stop": '<|im_end|>', #state.sep if state.sep_style in [SeparatorStyle.PLAIN, ] else state.sep2,
|
277 |
+
"images": f'List of {len(state.get_images())} images: {all_image_hash}',
|
278 |
+
}
|
279 |
+
logger.info(f"==== request ====\n{pload}")
|
280 |
+
|
281 |
+
pload['images'] = state.get_images()
|
282 |
+
print('=========> get_images')
|
283 |
+
state.messages[-1][-1] = "▌"
|
284 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
285 |
+
print('=========> state', state.messages[-1][-1])
|
286 |
+
|
287 |
+
try:
|
288 |
+
# Stream output
|
289 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
290 |
+
headers=headers, json=pload, stream=True, timeout=1000)
|
291 |
+
print("====> response ok")
|
292 |
+
print("====> response dir", dir(response))
|
293 |
+
print("====> response", response)
|
294 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
295 |
+
if chunk:
|
296 |
+
data = json.loads(chunk.decode())
|
297 |
+
if data["error_code"] == 0:
|
298 |
+
output = data["text"][len(prompt):].strip()
|
299 |
+
state.messages[-1][-1] = output + "▌"
|
300 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
301 |
+
else:
|
302 |
+
output = data["text"] + f" (error_code: {data['error_code']})"
|
303 |
+
state.messages[-1][-1] = output
|
304 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn, enable_btn, enable_btn)
|
305 |
+
return
|
306 |
+
time.sleep(0.03)
|
307 |
+
except requests.exceptions.RequestException as e:
|
308 |
+
state.messages[-1][-1] = server_error_msg
|
309 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn, enable_btn, enable_btn)
|
310 |
+
return
|
311 |
+
|
312 |
+
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
313 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
314 |
+
|
315 |
+
finish_tstamp = time.time()
|
316 |
+
logger.info(f"{output}")
|
317 |
+
|
318 |
+
with open(get_conv_log_filename(), "a") as fout:
|
319 |
+
data = {
|
320 |
+
"tstamp": round(finish_tstamp, 4),
|
321 |
+
"type": "chat",
|
322 |
+
"model": model_name,
|
323 |
+
"start": round(start_tstamp, 4),
|
324 |
+
"finish": round(finish_tstamp, 4),
|
325 |
+
"state": state.dict(),
|
326 |
+
"images": all_image_hash,
|
327 |
+
"ip": request.client.host,
|
328 |
+
}
|
329 |
+
fout.write(json.dumps(data) + "\n")
|
330 |
+
|
331 |
+
|
332 |
+
title_markdown = ("""
|
333 |
+
# 🐰 Bunny: A family of lightweight multimodal models
|
334 |
+
|
335 |
+
[📖[Technical report](https://arxiv.org/abs/2402.11530)] | [🏠[Code](https://github.com/BAAI-DCAI/Bunny)] | [🤗[Model](https://huggingface.co/BAAI/Bunny-v1_0-3B)]
|
336 |
+
|
337 |
+
""")
|
338 |
+
|
339 |
+
tos_markdown = ("""
|
340 |
+
### Terms of use
|
341 |
+
By using this service, users are required to agree to the following terms:
|
342 |
+
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
343 |
+
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
344 |
+
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
345 |
+
""")
|
346 |
+
|
347 |
+
learn_more_markdown = ("""
|
348 |
+
### License
|
349 |
+
This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses. The content of this project itself is licensed under the Apache license 2.0.
|
350 |
+
""")
|
351 |
+
|
352 |
+
block_css = """
|
353 |
+
.centered {
|
354 |
+
text-align: center;
|
355 |
+
}
|
356 |
+
#buttons button {
|
357 |
+
min-width: min(120px,100%);
|
358 |
+
}
|
359 |
+
#file-downloader {
|
360 |
+
min-width: min(120px,100%);
|
361 |
+
height: 50px;
|
362 |
+
}
|
363 |
+
"""
|
364 |
+
|
365 |
+
|
366 |
+
def trigger_download(doc_path):
|
367 |
+
return doc_path
|
368 |
+
|
369 |
+
|
370 |
+
def build_demo(embed_mode):
|
371 |
+
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
|
372 |
+
with gr.Blocks(title="Bunny", theme=gr.themes.Default(primary_hue="blue", secondary_hue="lime"),
|
373 |
+
css=block_css) as demo:
|
374 |
+
state = gr.State()
|
375 |
+
|
376 |
+
if not embed_mode:
|
377 |
+
gr.Markdown(title_markdown)
|
378 |
+
|
379 |
+
with gr.Row():
|
380 |
+
with gr.Column(scale=4):
|
381 |
+
with gr.Row(elem_id="model_selector_row"):
|
382 |
+
model_selector = gr.Dropdown(
|
383 |
+
choices=models,
|
384 |
+
value=models[0] if len(models) > 0 else "",
|
385 |
+
interactive=True,
|
386 |
+
show_label=False,
|
387 |
+
container=False,
|
388 |
+
allow_custom_value=True
|
389 |
+
)
|
390 |
+
|
391 |
+
imagebox = gr.Image(type="pil")
|
392 |
+
image_process_mode = gr.Radio(
|
393 |
+
["Crop", "Resize", "Pad", "Default"],
|
394 |
+
value="Default",
|
395 |
+
label="Preprocess for non-square image", visible=False)
|
396 |
+
|
397 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
398 |
+
gr.Examples(examples=[
|
399 |
+
[f"{cur_dir}/examples/example_1.png", "What is the astronaut holding in his hand?"],
|
400 |
+
[f"{cur_dir}/examples/example_2.png", "Why is the image funny?"],
|
401 |
+
], inputs=[imagebox, textbox])
|
402 |
+
|
403 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
404 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
|
405 |
+
label="Temperature", )
|
406 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P", )
|
407 |
+
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
|
408 |
+
label="Max output tokens", )
|
409 |
+
|
410 |
+
file_output = gr.components.File(label="Download Document", visible=True, elem_id="file-downloader")
|
411 |
+
with gr.Column(scale=8):
|
412 |
+
chatbot = gr.Chatbot(elem_id="chatbot", label="Bunny Chatbot",
|
413 |
+
avatar_images=[f"{cur_dir}/examples/user.png", f"{cur_dir}/examples/icon.jpg"],
|
414 |
+
height=550)
|
415 |
+
with gr.Row():
|
416 |
+
with gr.Column(scale=8):
|
417 |
+
textbox.render()
|
418 |
+
with gr.Column(scale=1, min_width=50):
|
419 |
+
submit_btn = gr.Button(value="Send", variant="primary")
|
420 |
+
|
421 |
+
with gr.Row(elem_id="buttons") as button_row:
|
422 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
423 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
424 |
+
# stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
|
425 |
+
regenerate_btn = gr.Button(value="🔁 Regenerate", interactive=False)
|
426 |
+
clear_btn = gr.Button(value="🚮 Clear", interactive=False)
|
427 |
+
save_conversation_btn = gr.Button(value="🗃️ Save", interactive=False)
|
428 |
+
|
429 |
+
if not embed_mode:
|
430 |
+
gr.Markdown(tos_markdown)
|
431 |
+
gr.Markdown(learn_more_markdown)
|
432 |
+
url_params = gr.JSON(visible=False)
|
433 |
+
|
434 |
+
# Register listeners
|
435 |
+
btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn, save_conversation_btn]
|
436 |
+
|
437 |
+
upvote_btn.click(
|
438 |
+
upvote_last_response,
|
439 |
+
[state, model_selector],
|
440 |
+
[textbox, upvote_btn, downvote_btn]
|
441 |
+
)
|
442 |
+
downvote_btn.click(
|
443 |
+
downvote_last_response,
|
444 |
+
[state, model_selector],
|
445 |
+
[textbox, upvote_btn, downvote_btn]
|
446 |
+
)
|
447 |
+
|
448 |
+
regenerate_btn.click(
|
449 |
+
regenerate,
|
450 |
+
[state, image_process_mode],
|
451 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
452 |
+
queue=False
|
453 |
+
).then(
|
454 |
+
http_bot,
|
455 |
+
[state, model_selector, temperature, top_p, max_output_tokens],
|
456 |
+
[state, chatbot] + btn_list
|
457 |
+
)
|
458 |
+
|
459 |
+
clear_btn.click(
|
460 |
+
clear_history,
|
461 |
+
None,
|
462 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
463 |
+
queue=False
|
464 |
+
)
|
465 |
+
|
466 |
+
save_conversation_btn.click(
|
467 |
+
save_conversation,
|
468 |
+
inputs=[state],
|
469 |
+
outputs=file_output
|
470 |
+
)
|
471 |
+
|
472 |
+
textbox.submit(
|
473 |
+
add_text,
|
474 |
+
[state, textbox, imagebox, image_process_mode],
|
475 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
476 |
+
queue=False
|
477 |
+
).then(
|
478 |
+
http_bot,
|
479 |
+
[state, model_selector, temperature, top_p, max_output_tokens],
|
480 |
+
[state, chatbot] + btn_list
|
481 |
+
)
|
482 |
+
|
483 |
+
submit_btn.click(
|
484 |
+
add_text,
|
485 |
+
[state, textbox, imagebox, image_process_mode],
|
486 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
487 |
+
queue=False
|
488 |
+
).then(
|
489 |
+
http_bot,
|
490 |
+
[state, model_selector, temperature, top_p, max_output_tokens],
|
491 |
+
[state, chatbot] + btn_list
|
492 |
+
)
|
493 |
+
|
494 |
+
if args.model_list_mode == "once":
|
495 |
+
demo.load(
|
496 |
+
load_demo,
|
497 |
+
[url_params],
|
498 |
+
[state, model_selector],
|
499 |
+
_js=get_window_url_params,
|
500 |
+
queue=False
|
501 |
+
)
|
502 |
+
elif args.model_list_mode == "reload":
|
503 |
+
demo.load(
|
504 |
+
load_demo_refresh_model_list,
|
505 |
+
None,
|
506 |
+
[state, model_selector],
|
507 |
+
queue=False
|
508 |
+
)
|
509 |
+
else:
|
510 |
+
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
511 |
+
|
512 |
+
return demo
|
513 |
+
|
514 |
+
|
515 |
+
if __name__ == "__main__":
|
516 |
+
parser = argparse.ArgumentParser()
|
517 |
+
parser.add_argument("--host", type=str, default="127.0.0.1")
|
518 |
+
parser.add_argument("--port", type=int)
|
519 |
+
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
|
520 |
+
parser.add_argument("--concurrency-count", type=int, default=10)
|
521 |
+
parser.add_argument("--model-list-mode", type=str, default="once",
|
522 |
+
choices=["once", "reload"])
|
523 |
+
parser.add_argument("--share", action="store_true")
|
524 |
+
parser.add_argument("--moderate", action="store_true")
|
525 |
+
parser.add_argument("--embed", action="store_true")
|
526 |
+
args = parser.parse_args()
|
527 |
+
logger.info(f"args: {args}")
|
528 |
+
|
529 |
+
models = get_model_list()
|
530 |
+
logger.info(args)
|
531 |
+
|
532 |
+
model_path = os.getenv("model", "liuhaotian/llava-v1.6-mistral-7b")
|
533 |
+
concurrency_count = int(os.getenv("concurrency_count", 5))
|
534 |
+
|
535 |
+
controller_proc = start_controller()
|
536 |
+
model_path = 'qnguyen3/nanoLLaVA'
|
537 |
+
worker_proc = start_worker(model_path)
|
538 |
+
time.sleep(10)
|
539 |
+
exit_status = 0
|
540 |
+
try:
|
541 |
+
demo = build_demo(args.embed)
|
542 |
+
demo.launch(
|
543 |
+
server_name=args.host,
|
544 |
+
server_port=args.port,
|
545 |
+
share=args.share,
|
546 |
+
debug=True,
|
547 |
+
max_threads=10
|
548 |
+
)
|
549 |
+
except Exception as e:
|
550 |
+
print(e)
|
551 |
+
exit_status = 1
|
552 |
+
finally:
|
553 |
+
worker_proc.kill()
|
554 |
+
controller_proc.kill()
|
555 |
+
sys.exit(exit_status)
|
serve/builder.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import warnings
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig, logging, AutoModelForCausalLM
|
6 |
+
|
7 |
+
logging.set_verbosity_error()
|
8 |
+
|
9 |
+
def load_pretrained_model(model_path, model_base, model_name, model_type, load_8bit=False, load_4bit=False,
|
10 |
+
device_map="auto", device="cuda", **kwargs):
|
11 |
+
if model_type not in {'qwen1.5-1.8b', 'qwen1.5-0.5b'}:
|
12 |
+
raise ValueError(f"Unknown Model Type {model_type}")
|
13 |
+
|
14 |
+
kwargs = {"device_map": device_map, **kwargs}
|
15 |
+
|
16 |
+
if device != "cuda":
|
17 |
+
kwargs['device_map'] = {"": device}
|
18 |
+
|
19 |
+
if load_8bit:
|
20 |
+
kwargs['load_in_8bit'] = True
|
21 |
+
elif load_4bit:
|
22 |
+
kwargs['load_in_4bit'] = True
|
23 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
24 |
+
load_in_4bit=True,
|
25 |
+
bnb_4bit_compute_dtype=torch.float16,
|
26 |
+
bnb_4bit_use_double_quant=True,
|
27 |
+
bnb_4bit_quant_type='nf4'
|
28 |
+
)
|
29 |
+
else:
|
30 |
+
kwargs['torch_dtype'] = torch.float16
|
31 |
+
|
32 |
+
if 'lora' in model_name.lower() and model_base is None:
|
33 |
+
warnings.warn(
|
34 |
+
'There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument.')
|
35 |
+
if 'lora' in model_name.lower() and model_base is not None:
|
36 |
+
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
37 |
+
|
38 |
+
print('Loading nanoLLaVA from base model...')
|
39 |
+
if model_type == 'qwen1.5-1.8b' or model_type == 'qwen1.5-0.5b':
|
40 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
|
41 |
+
model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained,
|
42 |
+
**kwargs)
|
43 |
+
|
44 |
+
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
45 |
+
if model.lm_head.weight.shape[0] != token_num:
|
46 |
+
model.lm_head.weight = torch.nn.Parameter(
|
47 |
+
torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
48 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(
|
49 |
+
torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
50 |
+
|
51 |
+
print('Loading additional nanoLLaVA weights...')
|
52 |
+
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
53 |
+
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
54 |
+
else:
|
55 |
+
# this is probably from HF Hub
|
56 |
+
from huggingface_hub import hf_hub_download
|
57 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
58 |
+
cache_file = hf_hub_download(
|
59 |
+
repo_id=repo_id,
|
60 |
+
filename=filename,
|
61 |
+
subfolder=subfolder)
|
62 |
+
return torch.load(cache_file, map_location='cpu')
|
63 |
+
|
64 |
+
non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
|
65 |
+
|
66 |
+
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in
|
67 |
+
non_lora_trainables.items()}
|
68 |
+
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
69 |
+
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in
|
70 |
+
non_lora_trainables.items()}
|
71 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
72 |
+
|
73 |
+
from peft import PeftModel
|
74 |
+
print('Loading LoRA weights...')
|
75 |
+
model = PeftModel.from_pretrained(model, model_path)
|
76 |
+
print('Merging LoRA weights...')
|
77 |
+
model = model.merge_and_unload()
|
78 |
+
print('Model is loaded...')
|
79 |
+
elif model_base is not None:
|
80 |
+
# this may be mm projector only
|
81 |
+
print('Loading nanoLLaVA from base model...')
|
82 |
+
|
83 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
84 |
+
if model_type == 'qwen1.5-1.8b' or model_type == 'qwen1.5-0.5b':
|
85 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
|
86 |
+
model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
|
87 |
+
**kwargs)
|
88 |
+
|
89 |
+
mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
90 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
91 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
92 |
+
else:
|
93 |
+
if model_type == 'qwen1.5-1.8b' or model_type == 'qwen1.5-0.5b':
|
94 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
95 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
96 |
+
|
97 |
+
model.resize_token_embeddings(len(tokenizer))
|
98 |
+
|
99 |
+
vision_tower = model.get_vision_tower()
|
100 |
+
if not vision_tower.is_loaded:
|
101 |
+
vision_tower.load_model()
|
102 |
+
vision_tower.to(device=device, dtype=torch.float16)
|
103 |
+
image_processor = vision_tower.image_processor
|
104 |
+
|
105 |
+
if hasattr(model.config, "max_sequence_length"):
|
106 |
+
context_len = model.config.max_sequence_length
|
107 |
+
else:
|
108 |
+
context_len = 2048
|
109 |
+
|
110 |
+
if model.generation_config.pad_token_id is None:
|
111 |
+
model.generation_config.pad_token_id = model.generation_config.eos_token_id
|
112 |
+
|
113 |
+
return tokenizer, model, image_processor, context_len
|
serve/constants.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Model Constants
|
2 |
+
IGNORE_INDEX = -100
|
3 |
+
IMAGE_TOKEN_INDEX = -200
|
4 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
5 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
6 |
+
LOGDIR = "gradio-logs"
|
7 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
serve/controller.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A controller manages distributed workers.
|
3 |
+
It sends worker addresses to clients.
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
import dataclasses
|
7 |
+
import threading
|
8 |
+
import json
|
9 |
+
import time
|
10 |
+
import numpy as np
|
11 |
+
import requests
|
12 |
+
import uvicorn
|
13 |
+
|
14 |
+
from typing import List
|
15 |
+
from enum import Enum, auto
|
16 |
+
from fastapi import FastAPI, Request
|
17 |
+
from fastapi.responses import StreamingResponse
|
18 |
+
|
19 |
+
from .constants import CONTROLLER_HEART_BEAT_EXPIRATION
|
20 |
+
from .utils import build_logger, server_error_msg
|
21 |
+
|
22 |
+
logger = build_logger("controller", "controller.log")
|
23 |
+
|
24 |
+
|
25 |
+
class DispatchMethod(Enum):
|
26 |
+
LOTTERY = auto()
|
27 |
+
SHORTEST_QUEUE = auto()
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def from_str(cls, name):
|
31 |
+
if name == "lottery":
|
32 |
+
return cls.LOTTERY
|
33 |
+
elif name == "shortest_queue":
|
34 |
+
return cls.SHORTEST_QUEUE
|
35 |
+
else:
|
36 |
+
raise ValueError(f"Invalid dispatch method")
|
37 |
+
|
38 |
+
|
39 |
+
@dataclasses.dataclass
|
40 |
+
class WorkerInfo:
|
41 |
+
model_names: List[str]
|
42 |
+
speed: int
|
43 |
+
queue_length: int
|
44 |
+
check_heart_beat: bool
|
45 |
+
last_heart_beat: str
|
46 |
+
|
47 |
+
|
48 |
+
def heart_beat_controller(controller):
|
49 |
+
while True:
|
50 |
+
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
|
51 |
+
controller.remove_stable_workers_by_expiration()
|
52 |
+
|
53 |
+
|
54 |
+
class Controller:
|
55 |
+
def __init__(self, dispatch_method: str):
|
56 |
+
# Dict[str -> WorkerInfo]
|
57 |
+
self.worker_info = {}
|
58 |
+
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
|
59 |
+
|
60 |
+
self.heart_beat_thread = threading.Thread(
|
61 |
+
target=heart_beat_controller, args=(self,))
|
62 |
+
self.heart_beat_thread.start()
|
63 |
+
|
64 |
+
logger.info("Init controller")
|
65 |
+
|
66 |
+
def register_worker(self, worker_name: str, check_heart_beat: bool,
|
67 |
+
worker_status: dict):
|
68 |
+
if worker_name not in self.worker_info:
|
69 |
+
logger.info(f"Register a new worker: {worker_name}")
|
70 |
+
else:
|
71 |
+
logger.info(f"Register an existing worker: {worker_name}")
|
72 |
+
|
73 |
+
if not worker_status:
|
74 |
+
worker_status = self.get_worker_status(worker_name)
|
75 |
+
if not worker_status:
|
76 |
+
return False
|
77 |
+
|
78 |
+
self.worker_info[worker_name] = WorkerInfo(
|
79 |
+
worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
|
80 |
+
check_heart_beat, time.time())
|
81 |
+
|
82 |
+
logger.info(f"Register done: {worker_name}, {worker_status}")
|
83 |
+
return True
|
84 |
+
|
85 |
+
def get_worker_status(self, worker_name: str):
|
86 |
+
try:
|
87 |
+
r = requests.post(worker_name + "/worker_get_status", timeout=5)
|
88 |
+
except requests.exceptions.RequestException as e:
|
89 |
+
logger.error(f"Get status fails: {worker_name}, {e}")
|
90 |
+
return None
|
91 |
+
|
92 |
+
if r.status_code != 200:
|
93 |
+
logger.error(f"Get status fails: {worker_name}, {r}")
|
94 |
+
return None
|
95 |
+
|
96 |
+
return r.json()
|
97 |
+
|
98 |
+
def remove_worker(self, worker_name: str):
|
99 |
+
del self.worker_info[worker_name]
|
100 |
+
|
101 |
+
def refresh_all_workers(self):
|
102 |
+
old_info = dict(self.worker_info)
|
103 |
+
self.worker_info = {}
|
104 |
+
|
105 |
+
for w_name, w_info in old_info.items():
|
106 |
+
if not self.register_worker(w_name, w_info.check_heart_beat, None):
|
107 |
+
logger.info(f"Remove stale worker: {w_name}")
|
108 |
+
|
109 |
+
def list_models(self):
|
110 |
+
model_names = set()
|
111 |
+
|
112 |
+
for w_name, w_info in self.worker_info.items():
|
113 |
+
model_names.update(w_info.model_names)
|
114 |
+
|
115 |
+
return list(model_names)
|
116 |
+
|
117 |
+
def get_worker_address(self, model_name: str):
|
118 |
+
if self.dispatch_method == DispatchMethod.LOTTERY:
|
119 |
+
worker_names = []
|
120 |
+
worker_speeds = []
|
121 |
+
for w_name, w_info in self.worker_info.items():
|
122 |
+
if model_name in w_info.model_names:
|
123 |
+
worker_names.append(w_name)
|
124 |
+
worker_speeds.append(w_info.speed)
|
125 |
+
worker_speeds = np.array(worker_speeds, dtype=np.float32)
|
126 |
+
norm = np.sum(worker_speeds)
|
127 |
+
if norm < 1e-4:
|
128 |
+
return ""
|
129 |
+
worker_speeds = worker_speeds / norm
|
130 |
+
|
131 |
+
pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
|
132 |
+
worker_name = worker_names[pt]
|
133 |
+
return worker_name
|
134 |
+
|
135 |
+
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
|
136 |
+
worker_names = []
|
137 |
+
worker_qlen = []
|
138 |
+
for w_name, w_info in self.worker_info.items():
|
139 |
+
if model_name in w_info.model_names:
|
140 |
+
worker_names.append(w_name)
|
141 |
+
worker_qlen.append(w_info.queue_length / w_info.speed)
|
142 |
+
if len(worker_names) == 0:
|
143 |
+
return ""
|
144 |
+
min_index = np.argmin(worker_qlen)
|
145 |
+
w_name = worker_names[min_index]
|
146 |
+
self.worker_info[w_name].queue_length += 1
|
147 |
+
logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
|
148 |
+
return w_name
|
149 |
+
else:
|
150 |
+
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
|
151 |
+
|
152 |
+
def receive_heart_beat(self, worker_name: str, queue_length: int):
|
153 |
+
if worker_name not in self.worker_info:
|
154 |
+
logger.info(f"Receive unknown heart beat. {worker_name}")
|
155 |
+
return False
|
156 |
+
|
157 |
+
self.worker_info[worker_name].queue_length = queue_length
|
158 |
+
self.worker_info[worker_name].last_heart_beat = time.time()
|
159 |
+
logger.info(f"Receive heart beat. {worker_name}")
|
160 |
+
return True
|
161 |
+
|
162 |
+
def remove_stable_workers_by_expiration(self):
|
163 |
+
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
|
164 |
+
to_delete = []
|
165 |
+
for worker_name, w_info in self.worker_info.items():
|
166 |
+
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
|
167 |
+
to_delete.append(worker_name)
|
168 |
+
|
169 |
+
for worker_name in to_delete:
|
170 |
+
self.remove_worker(worker_name)
|
171 |
+
|
172 |
+
def worker_api_generate_stream(self, params):
|
173 |
+
worker_addr = self.get_worker_address(params["model"])
|
174 |
+
if not worker_addr:
|
175 |
+
logger.info(f"no worker: {params['model']}")
|
176 |
+
ret = {
|
177 |
+
"text": server_error_msg,
|
178 |
+
"error_code": 2,
|
179 |
+
}
|
180 |
+
yield json.dumps(ret).encode() + b"\0"
|
181 |
+
|
182 |
+
try:
|
183 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
184 |
+
json=params, stream=True, timeout=5)
|
185 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
186 |
+
if chunk:
|
187 |
+
yield chunk + b"\0"
|
188 |
+
except requests.exceptions.RequestException as e:
|
189 |
+
logger.info(f"worker timeout: {worker_addr}")
|
190 |
+
ret = {
|
191 |
+
"text": server_error_msg,
|
192 |
+
"error_code": 3,
|
193 |
+
}
|
194 |
+
yield json.dumps(ret).encode() + b"\0"
|
195 |
+
|
196 |
+
# Let the controller act as a worker to achieve hierarchical
|
197 |
+
# management. This can be used to connect isolated sub networks.
|
198 |
+
def worker_api_get_status(self):
|
199 |
+
model_names = set()
|
200 |
+
speed = 0
|
201 |
+
queue_length = 0
|
202 |
+
|
203 |
+
for w_name in self.worker_info:
|
204 |
+
worker_status = self.get_worker_status(w_name)
|
205 |
+
if worker_status is not None:
|
206 |
+
model_names.update(worker_status["model_names"])
|
207 |
+
speed += worker_status["speed"]
|
208 |
+
queue_length += worker_status["queue_length"]
|
209 |
+
|
210 |
+
return {
|
211 |
+
"model_names": list(model_names),
|
212 |
+
"speed": speed,
|
213 |
+
"queue_length": queue_length,
|
214 |
+
}
|
215 |
+
|
216 |
+
|
217 |
+
app = FastAPI()
|
218 |
+
|
219 |
+
|
220 |
+
@app.post("/register_worker")
|
221 |
+
async def register_worker(request: Request):
|
222 |
+
data = await request.json()
|
223 |
+
controller.register_worker(
|
224 |
+
data["worker_name"], data["check_heart_beat"],
|
225 |
+
data.get("worker_status", None))
|
226 |
+
|
227 |
+
|
228 |
+
@app.post("/refresh_all_workers")
|
229 |
+
async def refresh_all_workers():
|
230 |
+
models = controller.refresh_all_workers()
|
231 |
+
|
232 |
+
|
233 |
+
@app.post("/list_models")
|
234 |
+
async def list_models():
|
235 |
+
models = controller.list_models()
|
236 |
+
return {"models": models}
|
237 |
+
|
238 |
+
|
239 |
+
@app.post("/get_worker_address")
|
240 |
+
async def get_worker_address(request: Request):
|
241 |
+
data = await request.json()
|
242 |
+
addr = controller.get_worker_address(data["model"])
|
243 |
+
return {"address": addr}
|
244 |
+
|
245 |
+
|
246 |
+
@app.post("/receive_heart_beat")
|
247 |
+
async def receive_heart_beat(request: Request):
|
248 |
+
data = await request.json()
|
249 |
+
exist = controller.receive_heart_beat(
|
250 |
+
data["worker_name"], data["queue_length"])
|
251 |
+
return {"exist": exist}
|
252 |
+
|
253 |
+
|
254 |
+
@app.post("/worker_generate_stream")
|
255 |
+
async def worker_api_generate_stream(request: Request):
|
256 |
+
params = await request.json()
|
257 |
+
generator = controller.worker_api_generate_stream(params)
|
258 |
+
return StreamingResponse(generator)
|
259 |
+
|
260 |
+
|
261 |
+
@app.post("/worker_get_status")
|
262 |
+
async def worker_api_get_status(request: Request):
|
263 |
+
return controller.worker_api_get_status()
|
264 |
+
|
265 |
+
|
266 |
+
if __name__ == "__main__":
|
267 |
+
parser = argparse.ArgumentParser()
|
268 |
+
parser.add_argument("--host", type=str, default="localhost")
|
269 |
+
parser.add_argument("--port", type=int, default=21001)
|
270 |
+
parser.add_argument("--dispatch-method", type=str, choices=["lottery", "shortest_queue"], default="shortest_queue")
|
271 |
+
args = parser.parse_args()
|
272 |
+
logger.info(f"args: {args}")
|
273 |
+
|
274 |
+
controller = Controller(args.dispatch_method)
|
275 |
+
log_config = uvicorn.config.LOGGING_CONFIG
|
276 |
+
log_config['handlers']['default']['stream'] = 'ext://sys.stdout'
|
277 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
serve/conversation.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
|
6 |
+
class SeparatorStyle(Enum):
|
7 |
+
"""Different separator style."""
|
8 |
+
TWO = auto()
|
9 |
+
PLAIN = auto()
|
10 |
+
MPT = auto()
|
11 |
+
|
12 |
+
|
13 |
+
@dataclasses.dataclass
|
14 |
+
class Conversation:
|
15 |
+
"""A class that keeps all conversation history."""
|
16 |
+
system: str
|
17 |
+
roles: List[str]
|
18 |
+
messages: List[List[str]]
|
19 |
+
offset: int
|
20 |
+
sep_style: SeparatorStyle
|
21 |
+
sep: str = "###"
|
22 |
+
sep2: str = None
|
23 |
+
version: str = "Unknown"
|
24 |
+
|
25 |
+
skip_next: bool = False
|
26 |
+
|
27 |
+
def get_prompt(self):
|
28 |
+
messages = self.messages
|
29 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
30 |
+
messages = self.messages.copy()
|
31 |
+
init_role, init_msg = messages[0].copy()
|
32 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
33 |
+
if 'mmtag' in self.version:
|
34 |
+
messages[0] = (init_role, init_msg)
|
35 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
36 |
+
messages.insert(1, (self.roles[1], "Received."))
|
37 |
+
else:
|
38 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
39 |
+
|
40 |
+
if self.sep_style == SeparatorStyle.TWO:
|
41 |
+
seps = [self.sep, self.sep2]
|
42 |
+
ret = self.system + seps[0]
|
43 |
+
for i, (role, message) in enumerate(messages):
|
44 |
+
if message:
|
45 |
+
if type(message) is tuple:
|
46 |
+
message, _, _ = message
|
47 |
+
ret += role + ": " + message + seps[i % 2]
|
48 |
+
else:
|
49 |
+
ret += role + ":"
|
50 |
+
|
51 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
52 |
+
ret = self.system + self.sep
|
53 |
+
for role, message in messages:
|
54 |
+
if message:
|
55 |
+
if type(message) is tuple:
|
56 |
+
message, _, _ = message
|
57 |
+
ret += role + message + self.sep
|
58 |
+
else:
|
59 |
+
ret += role
|
60 |
+
|
61 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
62 |
+
seps = [self.sep, self.sep2]
|
63 |
+
ret = self.system
|
64 |
+
for i, (role, message) in enumerate(messages):
|
65 |
+
if message:
|
66 |
+
if type(message) is tuple:
|
67 |
+
message, _, _ = message
|
68 |
+
ret += message + seps[i % 2]
|
69 |
+
else:
|
70 |
+
ret += ""
|
71 |
+
else:
|
72 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
73 |
+
|
74 |
+
return ret
|
75 |
+
|
76 |
+
def append_message(self, role, message):
|
77 |
+
self.messages.append([role, message])
|
78 |
+
|
79 |
+
def get_images(self, return_pil=False):
|
80 |
+
images = []
|
81 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
82 |
+
if i % 2 == 0:
|
83 |
+
if type(msg) is tuple:
|
84 |
+
import base64
|
85 |
+
from io import BytesIO
|
86 |
+
from PIL import Image
|
87 |
+
msg, image, image_process_mode = msg
|
88 |
+
if image_process_mode == "Pad":
|
89 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
90 |
+
width, height = pil_img.size
|
91 |
+
if width == height:
|
92 |
+
return pil_img
|
93 |
+
elif width > height:
|
94 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
95 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
96 |
+
return result
|
97 |
+
else:
|
98 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
99 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
100 |
+
return result
|
101 |
+
|
102 |
+
image = expand2square(image)
|
103 |
+
elif image_process_mode in ["Default", "Crop"]:
|
104 |
+
pass
|
105 |
+
elif image_process_mode == "Resize":
|
106 |
+
image = image.resize((336, 336))
|
107 |
+
else:
|
108 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
109 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
110 |
+
aspect_ratio = max_hw / min_hw
|
111 |
+
max_len, min_len = 800, 400
|
112 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
113 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
114 |
+
W, H = image.size
|
115 |
+
if longest_edge != max(image.size):
|
116 |
+
if H > W:
|
117 |
+
H, W = longest_edge, shortest_edge
|
118 |
+
else:
|
119 |
+
H, W = shortest_edge, longest_edge
|
120 |
+
image = image.resize((W, H))
|
121 |
+
if return_pil:
|
122 |
+
images.append(image)
|
123 |
+
else:
|
124 |
+
buffered = BytesIO()
|
125 |
+
image.save(buffered, format="PNG")
|
126 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
127 |
+
images.append(img_b64_str)
|
128 |
+
return images
|
129 |
+
|
130 |
+
def to_gradio_chatbot(self):
|
131 |
+
ret = []
|
132 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
133 |
+
if i % 2 == 0:
|
134 |
+
if type(msg) is tuple:
|
135 |
+
import base64
|
136 |
+
from io import BytesIO
|
137 |
+
msg, image, image_process_mode = msg
|
138 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
139 |
+
aspect_ratio = max_hw / min_hw
|
140 |
+
max_len, min_len = 800, 400
|
141 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
142 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
143 |
+
W, H = image.size
|
144 |
+
if H > W:
|
145 |
+
H, W = longest_edge, shortest_edge
|
146 |
+
else:
|
147 |
+
H, W = shortest_edge, longest_edge
|
148 |
+
image = image.resize((W, H))
|
149 |
+
buffered = BytesIO()
|
150 |
+
image.save(buffered, format="JPEG")
|
151 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
152 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
153 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
154 |
+
ret.append([msg, None])
|
155 |
+
else:
|
156 |
+
ret.append([msg, None])
|
157 |
+
else:
|
158 |
+
ret[-1][-1] = msg
|
159 |
+
return ret
|
160 |
+
|
161 |
+
def copy(self):
|
162 |
+
return Conversation(
|
163 |
+
system=self.system,
|
164 |
+
roles=self.roles,
|
165 |
+
messages=[[x, y] for x, y in self.messages],
|
166 |
+
offset=self.offset,
|
167 |
+
sep_style=self.sep_style,
|
168 |
+
sep=self.sep,
|
169 |
+
sep2=self.sep2,
|
170 |
+
version=self.version)
|
171 |
+
|
172 |
+
def dict(self):
|
173 |
+
if len(self.get_images()) > 0:
|
174 |
+
return {
|
175 |
+
"system": self.system,
|
176 |
+
"roles": self.roles,
|
177 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
178 |
+
"offset": self.offset,
|
179 |
+
"sep": self.sep,
|
180 |
+
"sep2": self.sep2,
|
181 |
+
}
|
182 |
+
return {
|
183 |
+
"system": self.system,
|
184 |
+
"roles": self.roles,
|
185 |
+
"messages": self.messages,
|
186 |
+
"offset": self.offset,
|
187 |
+
"sep": self.sep,
|
188 |
+
"sep2": self.sep2,
|
189 |
+
}
|
190 |
+
|
191 |
+
|
192 |
+
conv_bunny = Conversation(
|
193 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
194 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
195 |
+
roles=("USER", "ASSISTANT"),
|
196 |
+
version="bunny",
|
197 |
+
messages=(),
|
198 |
+
offset=0,
|
199 |
+
sep_style=SeparatorStyle.TWO,
|
200 |
+
sep=" ",
|
201 |
+
sep2="<|endoftext|>",
|
202 |
+
)
|
203 |
+
|
204 |
+
conv_plain = Conversation(
|
205 |
+
system="",
|
206 |
+
roles=("", ""),
|
207 |
+
messages=(
|
208 |
+
),
|
209 |
+
offset=0,
|
210 |
+
sep_style=SeparatorStyle.PLAIN,
|
211 |
+
sep="\n",
|
212 |
+
)
|
213 |
+
|
214 |
+
conv_chatml_direct = Conversation(
|
215 |
+
system="""<|im_start|>system
|
216 |
+
Answer the questions.""",
|
217 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
218 |
+
version="mpt",
|
219 |
+
messages=(),
|
220 |
+
offset=0,
|
221 |
+
sep_style=SeparatorStyle.MPT,
|
222 |
+
sep="<|im_end|>",
|
223 |
+
)
|
224 |
+
|
225 |
+
default_conversation = conv_bunny
|
226 |
+
conv_templates = {
|
227 |
+
"default": conv_bunny,
|
228 |
+
"bunny": conv_bunny,
|
229 |
+
"plain": conv_plain,
|
230 |
+
"chatml_direct": conv_chatml_direct,
|
231 |
+
}
|
232 |
+
|
233 |
+
if __name__ == "__main__":
|
234 |
+
print(default_conversation.get_prompt())
|
serve/examples/example_1.png
ADDED
serve/examples/example_2.png
ADDED
serve/examples/icon.jpg
ADDED
serve/examples/user.png
ADDED
serve/gradio_web_server.py
ADDED
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
import gradio as gr
|
7 |
+
import requests
|
8 |
+
import hashlib
|
9 |
+
import pypandoc
|
10 |
+
import base64
|
11 |
+
|
12 |
+
from io import BytesIO
|
13 |
+
|
14 |
+
from .conversation import (default_conversation, conv_templates, SeparatorStyle)
|
15 |
+
from .constants import LOGDIR
|
16 |
+
from .utils import (build_logger, server_error_msg, violates_moderation, moderation_msg)
|
17 |
+
|
18 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
19 |
+
|
20 |
+
headers = {"User-Agent": "Bunny Client"}
|
21 |
+
|
22 |
+
no_change_btn = gr.update()
|
23 |
+
enable_btn = gr.update(interactive=True)
|
24 |
+
disable_btn = gr.update(interactive=False)
|
25 |
+
|
26 |
+
priority = {
|
27 |
+
"Bunny": "aaaaaaa",
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
def get_conv_log_filename():
|
32 |
+
t = datetime.datetime.now()
|
33 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
34 |
+
return name
|
35 |
+
|
36 |
+
|
37 |
+
def get_model_list():
|
38 |
+
ret = requests.post(args.controller_url + "/refresh_all_workers")
|
39 |
+
assert ret.status_code == 200
|
40 |
+
ret = requests.post(args.controller_url + "/list_models")
|
41 |
+
models = ret.json()["models"]
|
42 |
+
models.sort(key=lambda x: priority.get(x, x))
|
43 |
+
logger.info(f"Models: {models}")
|
44 |
+
return models
|
45 |
+
|
46 |
+
|
47 |
+
get_window_url_params = """
|
48 |
+
function() {
|
49 |
+
const params = new URLSearchParams(window.location.search);
|
50 |
+
url_params = Object.fromEntries(params);
|
51 |
+
console.log(url_params);
|
52 |
+
return url_params;
|
53 |
+
}
|
54 |
+
"""
|
55 |
+
|
56 |
+
|
57 |
+
def load_demo(url_params, request: gr.Request):
|
58 |
+
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
59 |
+
|
60 |
+
dropdown_update = gr.update(visible=True)
|
61 |
+
if "model" in url_params:
|
62 |
+
model = url_params["model"]
|
63 |
+
if model in models:
|
64 |
+
dropdown_update = gr.update(
|
65 |
+
value=model, visible=True)
|
66 |
+
|
67 |
+
state = default_conversation.copy()
|
68 |
+
return state, dropdown_update
|
69 |
+
|
70 |
+
|
71 |
+
def load_demo_refresh_model_list(request: gr.Request):
|
72 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
73 |
+
models = get_model_list()
|
74 |
+
state = default_conversation.copy()
|
75 |
+
dropdown_update = gr.update(
|
76 |
+
choices=models,
|
77 |
+
value=models[0] if len(models) > 0 else ""
|
78 |
+
)
|
79 |
+
return state, dropdown_update
|
80 |
+
|
81 |
+
|
82 |
+
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
83 |
+
with open(get_conv_log_filename(), "a") as fout:
|
84 |
+
data = {
|
85 |
+
"tstamp": round(time.time(), 4),
|
86 |
+
"type": vote_type,
|
87 |
+
"model": model_selector,
|
88 |
+
"state": state.dict(),
|
89 |
+
"ip": request.client.host,
|
90 |
+
}
|
91 |
+
fout.write(json.dumps(data) + "\n")
|
92 |
+
|
93 |
+
|
94 |
+
def upvote_last_response(state, model_selector, request: gr.Request):
|
95 |
+
logger.info(f"upvote. ip: {request.client.host}")
|
96 |
+
vote_last_response(state, "upvote", model_selector, request)
|
97 |
+
return ("",) + (disable_btn,) * 3
|
98 |
+
|
99 |
+
|
100 |
+
def downvote_last_response(state, model_selector, request: gr.Request):
|
101 |
+
logger.info(f"downvote. ip: {request.client.host}")
|
102 |
+
vote_last_response(state, "downvote", model_selector, request)
|
103 |
+
return ("",) + (disable_btn,) * 3
|
104 |
+
|
105 |
+
|
106 |
+
def flag_last_response(state, model_selector, request: gr.Request):
|
107 |
+
logger.info(f"flag. ip: {request.client.host}")
|
108 |
+
vote_last_response(state, "flag", model_selector, request)
|
109 |
+
return ("",) + (disable_btn,) * 3
|
110 |
+
|
111 |
+
|
112 |
+
def regenerate(state, image_process_mode, request: gr.Request):
|
113 |
+
logger.info(f"regenerate. ip: {request.client.host}")
|
114 |
+
state.messages[-1][-1] = None
|
115 |
+
prev_human_msg = state.messages[-2]
|
116 |
+
if type(prev_human_msg[1]) in (tuple, list):
|
117 |
+
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
|
118 |
+
state.skip_next = False
|
119 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
120 |
+
|
121 |
+
|
122 |
+
def clear_history(request: gr.Request):
|
123 |
+
logger.info(f"clear_history. ip: {request.client.host}")
|
124 |
+
state = default_conversation.copy()
|
125 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
126 |
+
|
127 |
+
|
128 |
+
def save_conversation(conversation):
|
129 |
+
print("save_conversation_wrapper is called")
|
130 |
+
html_content = "<html><body>"
|
131 |
+
|
132 |
+
for role, message in conversation.messages:
|
133 |
+
if isinstance(message, str): # only text
|
134 |
+
html_content += f"<p><b>{role}</b>: {message}</p>"
|
135 |
+
elif isinstance(message, tuple): # text+image
|
136 |
+
text, image_obj, _ = message
|
137 |
+
|
138 |
+
# add text
|
139 |
+
if text:
|
140 |
+
html_content += f"<p><b>{role}</b>: {text}</p>"
|
141 |
+
|
142 |
+
# add image
|
143 |
+
buffered = BytesIO()
|
144 |
+
image_obj.save(buffered, format="PNG")
|
145 |
+
encoded_image = base64.b64encode(buffered.getvalue()).decode()
|
146 |
+
html_content += f'<img src="data:image/png;base64,{encoded_image}" /><br>'
|
147 |
+
|
148 |
+
html_content += "</body></html>"
|
149 |
+
|
150 |
+
doc_path = "./conversation.docx"
|
151 |
+
pypandoc.convert_text(html_content, 'docx', format='html', outputfile=doc_path,
|
152 |
+
extra_args=["-M2GB", "+RTS", "-K64m", "-RTS"])
|
153 |
+
return doc_path
|
154 |
+
|
155 |
+
|
156 |
+
def add_text(state, text, image, image_process_mode, request: gr.Request):
|
157 |
+
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
|
158 |
+
if len(text) <= 0 and image is None:
|
159 |
+
state.skip_next = True
|
160 |
+
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
|
161 |
+
if args.moderate:
|
162 |
+
flagged = violates_moderation(text)
|
163 |
+
if flagged:
|
164 |
+
state.skip_next = True
|
165 |
+
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
|
166 |
+
no_change_btn,) * 5
|
167 |
+
|
168 |
+
text = text[:1536] # Hard cut-off
|
169 |
+
if image is not None:
|
170 |
+
text = text[:1200] # Hard cut-off for images
|
171 |
+
if '<image>' not in text:
|
172 |
+
# text = '<Image><image></Image>' + text
|
173 |
+
text = text + '\n<image>'
|
174 |
+
text = (text, image, image_process_mode)
|
175 |
+
if len(state.get_images(return_pil=True)) > 0:
|
176 |
+
state = default_conversation.copy()
|
177 |
+
logger.info(f"Input Text: {text}")
|
178 |
+
state.append_message(state.roles[0], text)
|
179 |
+
state.append_message(state.roles[1], None)
|
180 |
+
state.skip_next = False
|
181 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
182 |
+
|
183 |
+
|
184 |
+
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
|
185 |
+
logger.info(f"http_bot. ip: {request.client.host}")
|
186 |
+
start_tstamp = time.time()
|
187 |
+
model_name = model_selector
|
188 |
+
|
189 |
+
if state.skip_next:
|
190 |
+
# This generate call is skipped due to invalid inputs
|
191 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
192 |
+
return
|
193 |
+
|
194 |
+
if len(state.messages) == state.offset + 2:
|
195 |
+
template_name = "bunny"
|
196 |
+
new_state = conv_templates[template_name].copy()
|
197 |
+
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
198 |
+
new_state.append_message(new_state.roles[1], None)
|
199 |
+
state = new_state
|
200 |
+
|
201 |
+
logger.info(f"Processed Input Text: {state.messages[-2][1]}")
|
202 |
+
# Query worker address
|
203 |
+
controller_url = args.controller_url
|
204 |
+
ret = requests.post(controller_url + "/get_worker_address",
|
205 |
+
json={"model": model_name})
|
206 |
+
worker_addr = ret.json()["address"]
|
207 |
+
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
208 |
+
|
209 |
+
# No available worker
|
210 |
+
if worker_addr == "":
|
211 |
+
state.messages[-1][-1] = server_error_msg
|
212 |
+
yield (state, state.to_gradio_chatbot(), enable_btn, enable_btn, enable_btn)
|
213 |
+
return
|
214 |
+
|
215 |
+
# Construct prompt
|
216 |
+
prompt = state.get_prompt()
|
217 |
+
|
218 |
+
all_images = state.get_images(return_pil=True)
|
219 |
+
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
|
220 |
+
for image, hash in zip(all_images, all_image_hash):
|
221 |
+
t = datetime.datetime.now()
|
222 |
+
filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
|
223 |
+
if not os.path.isfile(filename):
|
224 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
225 |
+
image.save(filename)
|
226 |
+
|
227 |
+
# Make requests
|
228 |
+
pload = {
|
229 |
+
"model": model_name,
|
230 |
+
"prompt": prompt,
|
231 |
+
"temperature": float(temperature),
|
232 |
+
"top_p": float(top_p),
|
233 |
+
"max_new_tokens": min(int(max_new_tokens), 1536),
|
234 |
+
"stop": state.sep if state.sep_style in [SeparatorStyle.PLAIN, ] else state.sep2,
|
235 |
+
"images": f'List of {len(state.get_images())} images: {all_image_hash}',
|
236 |
+
}
|
237 |
+
logger.info(f"==== request ====\n{pload}")
|
238 |
+
|
239 |
+
pload['images'] = state.get_images()
|
240 |
+
print('=========> get_images')
|
241 |
+
state.messages[-1][-1] = "▌"
|
242 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
243 |
+
print('=========> state', state.messages[-1][-1])
|
244 |
+
|
245 |
+
try:
|
246 |
+
# Stream output
|
247 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
248 |
+
headers=headers, json=pload, stream=True, timeout=1000)
|
249 |
+
print("====> response ok")
|
250 |
+
print("====> response dir", dir(response))
|
251 |
+
print("====> response", response)
|
252 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
253 |
+
if chunk:
|
254 |
+
data = json.loads(chunk.decode())
|
255 |
+
if data["error_code"] == 0:
|
256 |
+
output = data["text"][len(prompt):].strip()
|
257 |
+
state.messages[-1][-1] = output + "▌"
|
258 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
259 |
+
else:
|
260 |
+
output = data["text"] + f" (error_code: {data['error_code']})"
|
261 |
+
state.messages[-1][-1] = output
|
262 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn, enable_btn, enable_btn)
|
263 |
+
return
|
264 |
+
time.sleep(0.03)
|
265 |
+
except requests.exceptions.RequestException as e:
|
266 |
+
state.messages[-1][-1] = server_error_msg
|
267 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn, enable_btn, enable_btn)
|
268 |
+
return
|
269 |
+
|
270 |
+
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
271 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
272 |
+
|
273 |
+
finish_tstamp = time.time()
|
274 |
+
logger.info(f"{output}")
|
275 |
+
|
276 |
+
with open(get_conv_log_filename(), "a") as fout:
|
277 |
+
data = {
|
278 |
+
"tstamp": round(finish_tstamp, 4),
|
279 |
+
"type": "chat",
|
280 |
+
"model": model_name,
|
281 |
+
"start": round(start_tstamp, 4),
|
282 |
+
"finish": round(finish_tstamp, 4),
|
283 |
+
"state": state.dict(),
|
284 |
+
"images": all_image_hash,
|
285 |
+
"ip": request.client.host,
|
286 |
+
}
|
287 |
+
fout.write(json.dumps(data) + "\n")
|
288 |
+
|
289 |
+
|
290 |
+
title_markdown = ("""
|
291 |
+
# 🐰 Bunny: A family of lightweight multimodal models
|
292 |
+
|
293 |
+
[📖[Technical report](https://arxiv.org/abs/2402.11530)] | [🏠[Code](https://github.com/BAAI-DCAI/Bunny)] | [🤗[Model](https://huggingface.co/BAAI/Bunny-v1_0-3B)]
|
294 |
+
|
295 |
+
""")
|
296 |
+
|
297 |
+
tos_markdown = ("""
|
298 |
+
### Terms of use
|
299 |
+
By using this service, users are required to agree to the following terms:
|
300 |
+
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
301 |
+
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
302 |
+
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
303 |
+
""")
|
304 |
+
|
305 |
+
learn_more_markdown = ("""
|
306 |
+
### License
|
307 |
+
This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses. The content of this project itself is licensed under the Apache license 2.0.
|
308 |
+
""")
|
309 |
+
|
310 |
+
block_css = """
|
311 |
+
.centered {
|
312 |
+
text-align: center;
|
313 |
+
}
|
314 |
+
#buttons button {
|
315 |
+
min-width: min(120px,100%);
|
316 |
+
}
|
317 |
+
#file-downloader {
|
318 |
+
min-width: min(120px,100%);
|
319 |
+
height: 50px;
|
320 |
+
}
|
321 |
+
"""
|
322 |
+
|
323 |
+
|
324 |
+
def trigger_download(doc_path):
|
325 |
+
return doc_path
|
326 |
+
|
327 |
+
|
328 |
+
def build_demo(embed_mode):
|
329 |
+
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
|
330 |
+
with gr.Blocks(title="Bunny", theme=gr.themes.Default(primary_hue="blue", secondary_hue="lime"),
|
331 |
+
css=block_css) as demo:
|
332 |
+
state = gr.State()
|
333 |
+
|
334 |
+
if not embed_mode:
|
335 |
+
gr.Markdown(title_markdown)
|
336 |
+
|
337 |
+
with gr.Row():
|
338 |
+
with gr.Column(scale=4):
|
339 |
+
with gr.Row(elem_id="model_selector_row"):
|
340 |
+
model_selector = gr.Dropdown(
|
341 |
+
choices=models,
|
342 |
+
value=models[0] if len(models) > 0 else "",
|
343 |
+
interactive=True,
|
344 |
+
show_label=False,
|
345 |
+
container=False,
|
346 |
+
allow_custom_value=True
|
347 |
+
)
|
348 |
+
|
349 |
+
imagebox = gr.Image(type="pil")
|
350 |
+
image_process_mode = gr.Radio(
|
351 |
+
["Crop", "Resize", "Pad", "Default"],
|
352 |
+
value="Default",
|
353 |
+
label="Preprocess for non-square image", visible=False)
|
354 |
+
|
355 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
356 |
+
gr.Examples(examples=[
|
357 |
+
[f"{cur_dir}/examples/example_1.png", "What is the astronaut holding in his hand?"],
|
358 |
+
[f"{cur_dir}/examples/example_2.png", "Why is the image funny?"],
|
359 |
+
], inputs=[imagebox, textbox])
|
360 |
+
|
361 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
362 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
|
363 |
+
label="Temperature", )
|
364 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P", )
|
365 |
+
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
|
366 |
+
label="Max output tokens", )
|
367 |
+
|
368 |
+
file_output = gr.components.File(label="Download Document", visible=True, elem_id="file-downloader")
|
369 |
+
with gr.Column(scale=8):
|
370 |
+
chatbot = gr.Chatbot(elem_id="chatbot", label="Bunny Chatbot",
|
371 |
+
avatar_images=[f"{cur_dir}/examples/user.png", f"{cur_dir}/examples/icon.jpg"],
|
372 |
+
height=550)
|
373 |
+
with gr.Row():
|
374 |
+
with gr.Column(scale=8):
|
375 |
+
textbox.render()
|
376 |
+
with gr.Column(scale=1, min_width=50):
|
377 |
+
submit_btn = gr.Button(value="Send", variant="primary")
|
378 |
+
|
379 |
+
with gr.Row(elem_id="buttons") as button_row:
|
380 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
381 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
382 |
+
# stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
|
383 |
+
regenerate_btn = gr.Button(value="🔁 Regenerate", interactive=False)
|
384 |
+
clear_btn = gr.Button(value="🚮 Clear", interactive=False)
|
385 |
+
save_conversation_btn = gr.Button(value="🗃️ Save", interactive=False)
|
386 |
+
|
387 |
+
if not embed_mode:
|
388 |
+
gr.Markdown(tos_markdown)
|
389 |
+
gr.Markdown(learn_more_markdown)
|
390 |
+
url_params = gr.JSON(visible=False)
|
391 |
+
|
392 |
+
# Register listeners
|
393 |
+
btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn, save_conversation_btn]
|
394 |
+
|
395 |
+
upvote_btn.click(
|
396 |
+
upvote_last_response,
|
397 |
+
[state, model_selector],
|
398 |
+
[textbox, upvote_btn, downvote_btn]
|
399 |
+
)
|
400 |
+
downvote_btn.click(
|
401 |
+
downvote_last_response,
|
402 |
+
[state, model_selector],
|
403 |
+
[textbox, upvote_btn, downvote_btn]
|
404 |
+
)
|
405 |
+
|
406 |
+
regenerate_btn.click(
|
407 |
+
regenerate,
|
408 |
+
[state, image_process_mode],
|
409 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
410 |
+
queue=False
|
411 |
+
).then(
|
412 |
+
http_bot,
|
413 |
+
[state, model_selector, temperature, top_p, max_output_tokens],
|
414 |
+
[state, chatbot] + btn_list
|
415 |
+
)
|
416 |
+
|
417 |
+
clear_btn.click(
|
418 |
+
clear_history,
|
419 |
+
None,
|
420 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
421 |
+
queue=False
|
422 |
+
)
|
423 |
+
|
424 |
+
save_conversation_btn.click(
|
425 |
+
save_conversation,
|
426 |
+
inputs=[state],
|
427 |
+
outputs=file_output
|
428 |
+
)
|
429 |
+
|
430 |
+
textbox.submit(
|
431 |
+
add_text,
|
432 |
+
[state, textbox, imagebox, image_process_mode],
|
433 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
434 |
+
queue=False
|
435 |
+
).then(
|
436 |
+
http_bot,
|
437 |
+
[state, model_selector, temperature, top_p, max_output_tokens],
|
438 |
+
[state, chatbot] + btn_list
|
439 |
+
)
|
440 |
+
|
441 |
+
submit_btn.click(
|
442 |
+
add_text,
|
443 |
+
[state, textbox, imagebox, image_process_mode],
|
444 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
445 |
+
queue=False
|
446 |
+
).then(
|
447 |
+
http_bot,
|
448 |
+
[state, model_selector, temperature, top_p, max_output_tokens],
|
449 |
+
[state, chatbot] + btn_list
|
450 |
+
)
|
451 |
+
|
452 |
+
if args.model_list_mode == "once":
|
453 |
+
demo.load(
|
454 |
+
load_demo,
|
455 |
+
[url_params],
|
456 |
+
[state, model_selector],
|
457 |
+
_js=get_window_url_params,
|
458 |
+
queue=False
|
459 |
+
)
|
460 |
+
elif args.model_list_mode == "reload":
|
461 |
+
demo.load(
|
462 |
+
load_demo_refresh_model_list,
|
463 |
+
None,
|
464 |
+
[state, model_selector],
|
465 |
+
queue=False
|
466 |
+
)
|
467 |
+
else:
|
468 |
+
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
469 |
+
|
470 |
+
return demo
|
471 |
+
|
472 |
+
|
473 |
+
if __name__ == "__main__":
|
474 |
+
parser = argparse.ArgumentParser()
|
475 |
+
parser.add_argument("--host", type=str, default="127.0.0.1")
|
476 |
+
parser.add_argument("--port", type=int)
|
477 |
+
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
|
478 |
+
parser.add_argument("--concurrency-count", type=int, default=10)
|
479 |
+
parser.add_argument("--model-list-mode", type=str, default="once",
|
480 |
+
choices=["once", "reload"])
|
481 |
+
parser.add_argument("--share", action="store_true")
|
482 |
+
parser.add_argument("--moderate", action="store_true")
|
483 |
+
parser.add_argument("--embed", action="store_true")
|
484 |
+
args = parser.parse_args()
|
485 |
+
logger.info(f"args: {args}")
|
486 |
+
|
487 |
+
models = get_model_list()
|
488 |
+
logger.info(args)
|
489 |
+
demo = build_demo(args.embed)
|
490 |
+
demo.launch(
|
491 |
+
server_name=args.host,
|
492 |
+
server_port=args.port,
|
493 |
+
share=args.share,
|
494 |
+
debug=True,
|
495 |
+
max_threads=10
|
496 |
+
)
|
serve/mm_utils.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from PIL import Image
|
5 |
+
from io import BytesIO
|
6 |
+
from transformers import StoppingCriteria
|
7 |
+
|
8 |
+
from .constants import IMAGE_TOKEN_INDEX
|
9 |
+
|
10 |
+
|
11 |
+
def load_image_from_base64(image):
|
12 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
13 |
+
|
14 |
+
|
15 |
+
def expand2square(pil_img, background_color):
|
16 |
+
width, height = pil_img.size
|
17 |
+
if width == height:
|
18 |
+
return pil_img
|
19 |
+
elif width > height:
|
20 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
21 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
22 |
+
return result
|
23 |
+
else:
|
24 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
25 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
26 |
+
return result
|
27 |
+
|
28 |
+
|
29 |
+
def process_images(images, image_processor, model_cfg):
|
30 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
31 |
+
new_images = []
|
32 |
+
if image_aspect_ratio == 'pad':
|
33 |
+
for image in images:
|
34 |
+
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
|
35 |
+
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
36 |
+
new_images.append(image)
|
37 |
+
else:
|
38 |
+
return image_processor(images, return_tensors='pt')['pixel_values']
|
39 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
40 |
+
new_images = torch.stack(new_images, dim=0)
|
41 |
+
return new_images
|
42 |
+
|
43 |
+
|
44 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
45 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
46 |
+
|
47 |
+
def insert_separator(X, sep):
|
48 |
+
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
|
49 |
+
|
50 |
+
input_ids = []
|
51 |
+
offset = 0
|
52 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
53 |
+
offset = 1
|
54 |
+
input_ids.append(prompt_chunks[0][0])
|
55 |
+
|
56 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
57 |
+
input_ids.extend(x[offset:])
|
58 |
+
|
59 |
+
if return_tensors is not None:
|
60 |
+
if return_tensors == 'pt':
|
61 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
62 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
63 |
+
return input_ids
|
64 |
+
|
65 |
+
|
66 |
+
def get_model_name_from_path(model_path):
|
67 |
+
model_path = model_path.strip("/")
|
68 |
+
model_paths = model_path.split("/")
|
69 |
+
if model_paths[-1].startswith('checkpoint-'):
|
70 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
71 |
+
else:
|
72 |
+
return model_paths[-1]
|
73 |
+
|
74 |
+
|
75 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
76 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
77 |
+
self.keywords = keywords
|
78 |
+
self.keyword_ids = []
|
79 |
+
self.max_keyword_len = 0
|
80 |
+
for keyword in keywords:
|
81 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
82 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
83 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
84 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
85 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
86 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
87 |
+
self.tokenizer = tokenizer
|
88 |
+
self.start_len = input_ids.shape[1]
|
89 |
+
|
90 |
+
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
91 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
92 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
93 |
+
for keyword_id in self.keyword_ids:
|
94 |
+
truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
|
95 |
+
if torch.equal(truncated_output_ids, keyword_id):
|
96 |
+
return True
|
97 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
98 |
+
for keyword in self.keywords:
|
99 |
+
if keyword in outputs:
|
100 |
+
return True
|
101 |
+
return False
|
102 |
+
|
103 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
104 |
+
outputs = []
|
105 |
+
for i in range(output_ids.shape[0]):
|
106 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
107 |
+
return all(outputs)
|
serve/model_worker.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import asyncio
|
3 |
+
import json
|
4 |
+
import time
|
5 |
+
import threading
|
6 |
+
import uuid
|
7 |
+
import requests
|
8 |
+
import torch
|
9 |
+
import uvicorn
|
10 |
+
import transformers
|
11 |
+
|
12 |
+
from fastapi import FastAPI, Request, BackgroundTasks
|
13 |
+
from fastapi.responses import StreamingResponse
|
14 |
+
from functools import partial
|
15 |
+
from transformers import TextIteratorStreamer
|
16 |
+
from threading import Thread
|
17 |
+
|
18 |
+
from .constants import WORKER_HEART_BEAT_INTERVAL
|
19 |
+
from .utils import (build_logger, server_error_msg, pretty_print_semaphore)
|
20 |
+
from .builder import load_pretrained_model
|
21 |
+
from .mm_utils import process_images, load_image_from_base64, tokenizer_image_token, get_model_name_from_path, \
|
22 |
+
KeywordsStoppingCriteria
|
23 |
+
from .constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
24 |
+
|
25 |
+
GB = 1 << 30
|
26 |
+
|
27 |
+
worker_id = str(uuid.uuid4())[:6]
|
28 |
+
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
|
29 |
+
global_counter = 0
|
30 |
+
|
31 |
+
model_semaphore = None
|
32 |
+
|
33 |
+
|
34 |
+
def heart_beat_worker(controller):
|
35 |
+
while True:
|
36 |
+
time.sleep(WORKER_HEART_BEAT_INTERVAL)
|
37 |
+
controller.send_heart_beat()
|
38 |
+
|
39 |
+
|
40 |
+
class ModelWorker:
|
41 |
+
def __init__(self, controller_addr, worker_addr,
|
42 |
+
worker_id, no_register,
|
43 |
+
model_path, model_base, model_name, model_type,
|
44 |
+
load_8bit, load_4bit, device):
|
45 |
+
self.controller_addr = controller_addr
|
46 |
+
self.worker_addr = worker_addr
|
47 |
+
self.worker_id = worker_id
|
48 |
+
if model_path.endswith("/"):
|
49 |
+
model_path = model_path[:-1]
|
50 |
+
if model_name is None:
|
51 |
+
self.model_name = get_model_name_from_path(model_path)
|
52 |
+
else:
|
53 |
+
self.model_name = model_name
|
54 |
+
|
55 |
+
self.device = device
|
56 |
+
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
|
57 |
+
transformers.logging.disable_progress_bar()
|
58 |
+
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
|
59 |
+
model_path, model_base, self.model_name, model_type, load_8bit, load_4bit, device=self.device)
|
60 |
+
self.is_multimodal = True
|
61 |
+
|
62 |
+
if not no_register:
|
63 |
+
self.register_to_controller()
|
64 |
+
self.heart_beat_thread = threading.Thread(
|
65 |
+
target=heart_beat_worker, args=(self,))
|
66 |
+
self.heart_beat_thread.start()
|
67 |
+
|
68 |
+
def register_to_controller(self):
|
69 |
+
logger.info("Register to controller")
|
70 |
+
|
71 |
+
url = self.controller_addr + "/register_worker"
|
72 |
+
data = {
|
73 |
+
"worker_name": self.worker_addr,
|
74 |
+
"check_heart_beat": True,
|
75 |
+
"worker_status": self.get_status()
|
76 |
+
}
|
77 |
+
r = requests.post(url, json=data)
|
78 |
+
assert r.status_code == 200
|
79 |
+
|
80 |
+
def send_heart_beat(self):
|
81 |
+
logger.info(f"Send heart beat. Models: {[self.model_name]}. "
|
82 |
+
f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
|
83 |
+
f"global_counter: {global_counter}")
|
84 |
+
|
85 |
+
url = self.controller_addr + "/receive_heart_beat"
|
86 |
+
|
87 |
+
while True:
|
88 |
+
try:
|
89 |
+
ret = requests.post(url, json={
|
90 |
+
"worker_name": self.worker_addr,
|
91 |
+
"queue_length": self.get_queue_length()}, timeout=5)
|
92 |
+
exist = ret.json()["exist"]
|
93 |
+
break
|
94 |
+
except requests.exceptions.RequestException as e:
|
95 |
+
logger.error(f"heart beat error: {e}")
|
96 |
+
time.sleep(5)
|
97 |
+
|
98 |
+
if not exist:
|
99 |
+
self.register_to_controller()
|
100 |
+
|
101 |
+
def get_queue_length(self):
|
102 |
+
if model_semaphore is None:
|
103 |
+
return 0
|
104 |
+
else:
|
105 |
+
return args.limit_model_concurrency - model_semaphore._value + (len(
|
106 |
+
model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
|
107 |
+
|
108 |
+
def get_status(self):
|
109 |
+
return {
|
110 |
+
"model_names": [self.model_name],
|
111 |
+
"speed": 1,
|
112 |
+
"queue_length": self.get_queue_length(),
|
113 |
+
}
|
114 |
+
|
115 |
+
@torch.inference_mode()
|
116 |
+
def generate_stream(self, params):
|
117 |
+
tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
|
118 |
+
|
119 |
+
prompt = params["prompt"]
|
120 |
+
ori_prompt = prompt
|
121 |
+
images = params.get("images", None)
|
122 |
+
num_image_tokens = 0
|
123 |
+
if images is not None and len(images) > 0 and self.is_multimodal:
|
124 |
+
if len(images) > 0:
|
125 |
+
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
|
126 |
+
raise ValueError("Number of images does not match number of <image> tokens in prompt")
|
127 |
+
|
128 |
+
images = [load_image_from_base64(image) for image in images]
|
129 |
+
images = process_images(images, image_processor, model.config)
|
130 |
+
print(f"----> process_images {images}")
|
131 |
+
print(f"----> process_images sum {torch.sum(images)}")
|
132 |
+
if type(images) is list:
|
133 |
+
images = [image.to(self.model.device, dtype=model.dtype) for image in images]
|
134 |
+
else:
|
135 |
+
images = images.to(self.model.device, dtype=model.dtype)
|
136 |
+
|
137 |
+
replace_token = DEFAULT_IMAGE_TOKEN
|
138 |
+
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
139 |
+
|
140 |
+
num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
|
141 |
+
else:
|
142 |
+
images = None
|
143 |
+
image_args = {"images": images}
|
144 |
+
else:
|
145 |
+
images = None
|
146 |
+
image_args = {}
|
147 |
+
|
148 |
+
temperature = float(params.get("temperature", 1.0))
|
149 |
+
top_p = float(params.get("top_p", 1.0))
|
150 |
+
max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
|
151 |
+
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
|
152 |
+
stop_str = params.get("stop", None)
|
153 |
+
do_sample = True if temperature > 0.001 else False
|
154 |
+
|
155 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(
|
156 |
+
self.device)
|
157 |
+
keywords = [stop_str]
|
158 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
159 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
|
160 |
+
|
161 |
+
max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
|
162 |
+
|
163 |
+
if max_new_tokens < 1:
|
164 |
+
yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.",
|
165 |
+
"error_code": 0}).encode() + b"\0"
|
166 |
+
return
|
167 |
+
print("max_new_tokens", max_new_tokens)
|
168 |
+
print("start!")
|
169 |
+
|
170 |
+
thread = Thread(target=model.generate, kwargs=dict(
|
171 |
+
inputs=input_ids,
|
172 |
+
do_sample=do_sample,
|
173 |
+
temperature=temperature,
|
174 |
+
top_p=top_p,
|
175 |
+
max_new_tokens=max_new_tokens,
|
176 |
+
streamer=streamer,
|
177 |
+
stopping_criteria=[stopping_criteria],
|
178 |
+
use_cache=True,
|
179 |
+
**image_args
|
180 |
+
))
|
181 |
+
thread.start()
|
182 |
+
|
183 |
+
generated_text = ori_prompt
|
184 |
+
for new_text in streamer:
|
185 |
+
if generated_text and not generated_text.endswith(' '):
|
186 |
+
generated_text += ' '
|
187 |
+
generated_text += new_text
|
188 |
+
if generated_text.endswith(stop_str):
|
189 |
+
generated_text = generated_text[:-len(stop_str)]
|
190 |
+
logger.info(f"new_text: {new_text}")
|
191 |
+
yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
|
192 |
+
|
193 |
+
def generate_stream_gate(self, params):
|
194 |
+
try:
|
195 |
+
for x in self.generate_stream(params):
|
196 |
+
yield x
|
197 |
+
except ValueError as e:
|
198 |
+
print("Caught ValueError:", e)
|
199 |
+
ret = {
|
200 |
+
"text": server_error_msg,
|
201 |
+
"error_code": 1,
|
202 |
+
}
|
203 |
+
yield json.dumps(ret).encode() + b"\0"
|
204 |
+
except torch.cuda.CudaError as e:
|
205 |
+
print("Caught torch.cuda.CudaError:", e)
|
206 |
+
ret = {
|
207 |
+
"text": server_error_msg,
|
208 |
+
"error_code": 1,
|
209 |
+
}
|
210 |
+
yield json.dumps(ret).encode() + b"\0"
|
211 |
+
except Exception as e:
|
212 |
+
print("Caught Unknown Error", e)
|
213 |
+
ret = {
|
214 |
+
"text": server_error_msg,
|
215 |
+
"error_code": 1,
|
216 |
+
}
|
217 |
+
yield json.dumps(ret).encode() + b"\0"
|
218 |
+
|
219 |
+
|
220 |
+
app = FastAPI()
|
221 |
+
|
222 |
+
|
223 |
+
def release_model_semaphore(fn=None):
|
224 |
+
model_semaphore.release()
|
225 |
+
if fn is not None:
|
226 |
+
fn()
|
227 |
+
|
228 |
+
|
229 |
+
@app.post("/worker_generate_stream")
|
230 |
+
async def generate_stream(request: Request):
|
231 |
+
global model_semaphore, global_counter
|
232 |
+
global_counter += 1
|
233 |
+
params = await request.json()
|
234 |
+
|
235 |
+
if model_semaphore is None:
|
236 |
+
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
|
237 |
+
await model_semaphore.acquire()
|
238 |
+
worker.send_heart_beat()
|
239 |
+
generator = worker.generate_stream_gate(params)
|
240 |
+
background_tasks = BackgroundTasks()
|
241 |
+
background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
|
242 |
+
return StreamingResponse(generator, background=background_tasks)
|
243 |
+
|
244 |
+
|
245 |
+
@app.post("/worker_get_status")
|
246 |
+
async def get_status(request: Request):
|
247 |
+
return worker.get_status()
|
248 |
+
|
249 |
+
|
250 |
+
if __name__ == "__main__":
|
251 |
+
parser = argparse.ArgumentParser()
|
252 |
+
parser.add_argument("--host", type=str, default="localhost")
|
253 |
+
parser.add_argument("--port", type=int, default=21002)
|
254 |
+
parser.add_argument("--worker-address", type=str,
|
255 |
+
default="http://localhost:21002")
|
256 |
+
parser.add_argument("--controller-address", type=str,
|
257 |
+
default="http://localhost:21001")
|
258 |
+
parser.add_argument("--model-path", type=str, default=None)
|
259 |
+
parser.add_argument("--model-base", type=str, default=None)
|
260 |
+
parser.add_argument("--model-name", type=str)
|
261 |
+
parser.add_argument("--model-type", type=str, default=None)
|
262 |
+
parser.add_argument("--device", type=str, default="cuda")
|
263 |
+
parser.add_argument("--multi-modal", action="store_true",
|
264 |
+
help="Multimodal mode is automatically detected with model name.")
|
265 |
+
parser.add_argument("--limit-model-concurrency", type=int, default=5)
|
266 |
+
parser.add_argument("--stream-interval", type=int, default=1)
|
267 |
+
parser.add_argument("--no-register", action="store_true")
|
268 |
+
parser.add_argument("--load-8bit", action="store_true")
|
269 |
+
parser.add_argument("--load-4bit", action="store_true")
|
270 |
+
args = parser.parse_args()
|
271 |
+
logger.info(f"args: {args}")
|
272 |
+
|
273 |
+
if args.multi_modal:
|
274 |
+
logger.warning("Multimodal mode is automatically detected with model name.")
|
275 |
+
|
276 |
+
worker = ModelWorker(args.controller_address,
|
277 |
+
args.worker_address,
|
278 |
+
worker_id,
|
279 |
+
args.no_register,
|
280 |
+
args.model_path,
|
281 |
+
args.model_base,
|
282 |
+
args.model_name,
|
283 |
+
args.model_type,
|
284 |
+
args.load_8bit,
|
285 |
+
args.load_4bit,
|
286 |
+
args.device)
|
287 |
+
|
288 |
+
log_config = uvicorn.config.LOGGING_CONFIG
|
289 |
+
log_config['handlers']['default']['stream'] = 'ext://sys.stdout'
|
290 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
serve/utils.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import logging.handlers
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
from .constants import LOGDIR
|
7 |
+
|
8 |
+
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
9 |
+
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
|
10 |
+
|
11 |
+
handler = None
|
12 |
+
|
13 |
+
|
14 |
+
def disable_torch_init():
|
15 |
+
"""
|
16 |
+
Disable the redundant torch default initialization to accelerate model creation.
|
17 |
+
"""
|
18 |
+
import torch
|
19 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
20 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
21 |
+
|
22 |
+
|
23 |
+
def build_logger(logger_name, logger_filename):
|
24 |
+
global handler
|
25 |
+
|
26 |
+
formatter = logging.Formatter(
|
27 |
+
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
28 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
29 |
+
)
|
30 |
+
|
31 |
+
# Set the format of root handlers
|
32 |
+
if not logging.getLogger().handlers:
|
33 |
+
logging.basicConfig(level=logging.INFO)
|
34 |
+
logging.getLogger().handlers[0].setFormatter(formatter)
|
35 |
+
|
36 |
+
# Redirect stdout and stderr to loggers
|
37 |
+
stdout_logger = logging.getLogger("stdout")
|
38 |
+
stdout_logger.setLevel(logging.INFO)
|
39 |
+
sl = StreamToLogger(stdout_logger, logging.INFO)
|
40 |
+
sys.stdout = sl
|
41 |
+
|
42 |
+
stderr_logger = logging.getLogger("stderr")
|
43 |
+
stderr_logger.setLevel(logging.ERROR)
|
44 |
+
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
45 |
+
sys.stderr = sl
|
46 |
+
|
47 |
+
# Get logger
|
48 |
+
logger = logging.getLogger(logger_name)
|
49 |
+
logger.setLevel(logging.INFO)
|
50 |
+
|
51 |
+
# Add a file handler for all loggers
|
52 |
+
if handler is None:
|
53 |
+
os.makedirs(LOGDIR, exist_ok=True)
|
54 |
+
filename = os.path.join(LOGDIR, logger_filename)
|
55 |
+
handler = logging.handlers.TimedRotatingFileHandler(
|
56 |
+
filename, when='D', utc=True, encoding='UTF-8')
|
57 |
+
handler.setFormatter(formatter)
|
58 |
+
|
59 |
+
for name, item in logging.root.manager.loggerDict.items():
|
60 |
+
if isinstance(item, logging.Logger):
|
61 |
+
item.addHandler(handler)
|
62 |
+
|
63 |
+
return logger
|
64 |
+
|
65 |
+
|
66 |
+
class StreamToLogger(object):
|
67 |
+
"""
|
68 |
+
Fake file-like stream object that redirects writes to a logger instance.
|
69 |
+
"""
|
70 |
+
|
71 |
+
def __init__(self, logger, log_level=logging.INFO):
|
72 |
+
self.terminal = sys.stdout
|
73 |
+
self.logger = logger
|
74 |
+
self.log_level = log_level
|
75 |
+
self.linebuf = ''
|
76 |
+
|
77 |
+
def __getattr__(self, attr):
|
78 |
+
return getattr(self.terminal, attr)
|
79 |
+
|
80 |
+
def write(self, buf):
|
81 |
+
temp_linebuf = self.linebuf + buf
|
82 |
+
self.linebuf = ''
|
83 |
+
for line in temp_linebuf.splitlines(True):
|
84 |
+
# From the io.TextIOWrapper docs:
|
85 |
+
# On output, if newline is None, any '\n' characters written
|
86 |
+
# are translated to the system default line separator.
|
87 |
+
# By default sys.stdout.write() expects '\n' newlines and then
|
88 |
+
# translates them so this is still cross platform.
|
89 |
+
if line[-1] == '\n':
|
90 |
+
self.logger.log(self.log_level, line.rstrip())
|
91 |
+
else:
|
92 |
+
self.linebuf += line
|
93 |
+
|
94 |
+
def flush(self):
|
95 |
+
if self.linebuf != '':
|
96 |
+
self.logger.log(self.log_level, self.linebuf.rstrip())
|
97 |
+
self.linebuf = ''
|
98 |
+
|
99 |
+
|
100 |
+
def violates_moderation(text):
|
101 |
+
"""
|
102 |
+
Check whether the text violates OpenAI moderation API.
|
103 |
+
"""
|
104 |
+
url = "https://api.openai.com/v1/moderations"
|
105 |
+
headers = {"Content-Type": "application/json",
|
106 |
+
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
|
107 |
+
text = text.replace("\n", "")
|
108 |
+
data = "{" + '"input": ' + f'"{text}"' + "}"
|
109 |
+
data = data.encode("utf-8")
|
110 |
+
try:
|
111 |
+
ret = requests.post(url, headers=headers, data=data, timeout=5)
|
112 |
+
flagged = ret.json()["results"][0]["flagged"]
|
113 |
+
except requests.exceptions.RequestException as e:
|
114 |
+
flagged = False
|
115 |
+
except KeyError as e:
|
116 |
+
flagged = False
|
117 |
+
|
118 |
+
return flagged
|
119 |
+
|
120 |
+
|
121 |
+
def pretty_print_semaphore(semaphore):
|
122 |
+
if semaphore is None:
|
123 |
+
return "None"
|
124 |
+
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
|