update gradio demo
Browse files- .streamlit/config.toml +0 -7
- README.md +2 -2
- app.py +115 -59
- {static → assets}/SimHei.ttf +0 -0
- assets/assistant.png +0 -0
- assets/human.png +0 -0
- controller.py +3 -1
- conversation.py +259 -0
- gallery/child_1.jpg +0 -0
- gallery/child_2.jpg +0 -0
- gallery/child_3.jpg +0 -0
- gradio_web_server.py +824 -0
- library.py +0 -95
- mm_utils.py +0 -102
- model_worker.py +283 -140
- requirements.txt +14 -4
- utils.py +63 -24
.streamlit/config.toml
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
[server]
|
2 |
-
enableStaticServing = false
|
3 |
-
enableXsrfProtection = false
|
4 |
-
enableCORS = false
|
5 |
-
|
6 |
-
[browser] # This ip and port will show in command prompt
|
7 |
-
enableCORS = false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -3,8 +3,8 @@ title: InternVL
|
|
3 |
emoji: ⚡
|
4 |
colorFrom: yellow
|
5 |
colorTo: gray
|
6 |
-
sdk:
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
|
|
3 |
emoji: ⚡
|
4 |
colorFrom: yellow
|
5 |
colorTo: gray
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.36.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
app.py
CHANGED
@@ -1,60 +1,116 @@
|
|
1 |
-
import
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
header {visibility: hidden;}
|
9 |
-
</style>
|
10 |
-
"""
|
11 |
-
|
12 |
-
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
|
13 |
-
|
14 |
-
st.markdown(
|
15 |
-
"""
|
16 |
-
<style>
|
17 |
-
html, body, .fullScreenFrame, .fullScreenFrame iframe {
|
18 |
-
margin: 0;
|
19 |
-
padding: 0;
|
20 |
-
height: 100%;
|
21 |
-
width: 100%;
|
22 |
-
border: none;
|
23 |
-
display: block;
|
24 |
-
overflow: hidden;
|
25 |
-
}
|
26 |
-
|
27 |
-
.fullScreenFrame {
|
28 |
-
position: fixed;
|
29 |
-
top: 0;
|
30 |
-
left: 0;
|
31 |
-
right: 0;
|
32 |
-
bottom: 0;
|
33 |
-
z-index: 9999;
|
34 |
-
}
|
35 |
-
|
36 |
-
.main .block-container {
|
37 |
-
padding: 0;
|
38 |
-
margin: 0;
|
39 |
-
height: 100vh;
|
40 |
-
}
|
41 |
-
|
42 |
-
/* Hide Streamlit header and footer */
|
43 |
-
header, footer {
|
44 |
-
display: none;
|
45 |
-
}
|
46 |
-
</style>
|
47 |
-
""",
|
48 |
-
unsafe_allow_html=True,
|
49 |
-
)
|
50 |
-
|
51 |
-
# Embed the external Streamlit webpage
|
52 |
-
st.markdown(
|
53 |
-
"""
|
54 |
-
<div class="fullScreenFrame">
|
55 |
-
<iframe src="https://internvl.opengvlab.com/"></iframe>
|
56 |
-
</div>
|
57 |
-
""",
|
58 |
-
unsafe_allow_html=True,
|
59 |
-
)
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import fire
|
2 |
+
import subprocess
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
import signal
|
6 |
+
import subprocess
|
7 |
+
import atexit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
+
|
10 |
+
def kill_processes_by_cmd_substring(cmd_substring):
|
11 |
+
# execute `ps -ef` and obtain its output
|
12 |
+
result = subprocess.run(["ps", "-ef"], stdout=subprocess.PIPE, text=True)
|
13 |
+
lines = result.stdout.splitlines()
|
14 |
+
|
15 |
+
# visit each line
|
16 |
+
for line in lines:
|
17 |
+
if cmd_substring in line:
|
18 |
+
# extract PID
|
19 |
+
parts = line.split()
|
20 |
+
pid = int(parts[1])
|
21 |
+
print(f"Killing process with PID: {pid}, CMD: {line}")
|
22 |
+
os.kill(pid, signal.SIGTERM)
|
23 |
+
|
24 |
+
|
25 |
+
def main(
|
26 |
+
python_path="python",
|
27 |
+
run_controller=True,
|
28 |
+
run_worker=True,
|
29 |
+
run_gradio=True,
|
30 |
+
controller_port=10086,
|
31 |
+
gradio_port=10087,
|
32 |
+
worker_names=[
|
33 |
+
"OpenGVLab/InternVL2-8B",
|
34 |
+
],
|
35 |
+
run_sd_worker=False,
|
36 |
+
**kwargs,
|
37 |
+
):
|
38 |
+
host = "http://0.0.0.0"
|
39 |
+
controller_process = None
|
40 |
+
if run_controller:
|
41 |
+
# python controller.py --host 0.0.0.0 --port 10086
|
42 |
+
cmd_args = [
|
43 |
+
f"{python_path}",
|
44 |
+
"controller.py",
|
45 |
+
"--host",
|
46 |
+
"0.0.0.0",
|
47 |
+
"--port",
|
48 |
+
f"{controller_port}",
|
49 |
+
]
|
50 |
+
kill_processes_by_cmd_substring(" ".join(cmd_args))
|
51 |
+
print("Launching controller: ", " ".join(cmd_args))
|
52 |
+
controller_process = subprocess.Popen(cmd_args)
|
53 |
+
atexit.register(controller_process.terminate)
|
54 |
+
|
55 |
+
worker_processes = []
|
56 |
+
if run_worker:
|
57 |
+
worker_port = 10088
|
58 |
+
for worker_name in worker_names:
|
59 |
+
cmd_args = [
|
60 |
+
f"{python_path}",
|
61 |
+
"model_worker.py",
|
62 |
+
"--port",
|
63 |
+
f"{worker_port}",
|
64 |
+
"--controller-url",
|
65 |
+
f"{host}:{controller_port}",
|
66 |
+
"--model-path",
|
67 |
+
f"{worker_name}",
|
68 |
+
"--load-8bit",
|
69 |
+
]
|
70 |
+
kill_processes_by_cmd_substring(" ".join(cmd_args))
|
71 |
+
print("Launching worker: ", " ".join(cmd_args))
|
72 |
+
worker_process = subprocess.Popen(cmd_args)
|
73 |
+
worker_processes.append(worker_process)
|
74 |
+
atexit.register(worker_process.terminate)
|
75 |
+
worker_port += 1
|
76 |
+
|
77 |
+
time.sleep(10)
|
78 |
+
gradio_process = None
|
79 |
+
if run_gradio:
|
80 |
+
# python gradio_web_server.py --port 10088 --controller-url http://0.0.0.0:10086
|
81 |
+
cmd_args = [
|
82 |
+
f"{python_path}",
|
83 |
+
"gradio_web_server.py",
|
84 |
+
"--port",
|
85 |
+
f"{gradio_port}",
|
86 |
+
"--controller-url",
|
87 |
+
f"{host}:{controller_port}",
|
88 |
+
"--model-list-mode",
|
89 |
+
"reload",
|
90 |
+
]
|
91 |
+
kill_processes_by_cmd_substring(" ".join(cmd_args))
|
92 |
+
print("Launching gradio: ", " ".join(cmd_args))
|
93 |
+
gradio_process = subprocess.Popen(cmd_args)
|
94 |
+
atexit.register(gradio_process.terminate)
|
95 |
+
|
96 |
+
sd_worker_process = None
|
97 |
+
if run_sd_worker:
|
98 |
+
# python model_worker.py --port 10088 --controller-address http://
|
99 |
+
cmd_args = [f"{python_path}", "sd_worker.py"]
|
100 |
+
kill_processes_by_cmd_substring(" ".join(cmd_args))
|
101 |
+
print("Launching sd_worker: ", " ".join(cmd_args))
|
102 |
+
sd_worker_process = subprocess.Popen(cmd_args)
|
103 |
+
atexit.register(sd_worker_process.terminate)
|
104 |
+
|
105 |
+
for worker_process in worker_processes:
|
106 |
+
worker_process.wait()
|
107 |
+
if controller_process:
|
108 |
+
controller_process.wait()
|
109 |
+
if gradio_process:
|
110 |
+
gradio_process.wait()
|
111 |
+
if sd_worker_process:
|
112 |
+
sd_worker_process.wait()
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
fire.Fire(main)
|
{static → assets}/SimHei.ttf
RENAMED
File without changes
|
assets/assistant.png
ADDED
assets/human.png
ADDED
controller.py
CHANGED
@@ -5,9 +5,9 @@ It sends worker addresses to clients.
|
|
5 |
import argparse
|
6 |
import dataclasses
|
7 |
import json
|
|
|
8 |
import threading
|
9 |
import time
|
10 |
-
import re
|
11 |
from enum import Enum, auto
|
12 |
from typing import List
|
13 |
|
@@ -113,6 +113,8 @@ class Controller:
|
|
113 |
model_names.update(w_info.model_names)
|
114 |
|
115 |
def extract_key(s):
|
|
|
|
|
116 |
match = re.match(r'InternVL2-(\d+)B', s)
|
117 |
if match:
|
118 |
return int(match.group(1))
|
|
|
5 |
import argparse
|
6 |
import dataclasses
|
7 |
import json
|
8 |
+
import re
|
9 |
import threading
|
10 |
import time
|
|
|
11 |
from enum import Enum, auto
|
12 |
from typing import List
|
13 |
|
|
|
113 |
model_names.update(w_info.model_names)
|
114 |
|
115 |
def extract_key(s):
|
116 |
+
if 'Pro' in s:
|
117 |
+
return 999
|
118 |
match = re.match(r'InternVL2-(\d+)B', s)
|
119 |
if match:
|
120 |
return int(match.group(1))
|
conversation.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import dataclasses
|
3 |
+
import base64
|
4 |
+
import copy
|
5 |
+
import hashlib
|
6 |
+
import datetime
|
7 |
+
from io import BytesIO
|
8 |
+
from PIL import Image
|
9 |
+
from typing import Any, List, Dict, Union
|
10 |
+
from dataclasses import field
|
11 |
+
|
12 |
+
from utils import LOGDIR
|
13 |
+
|
14 |
+
|
15 |
+
def pil2base64(img: Image.Image) -> str:
|
16 |
+
buffered = BytesIO()
|
17 |
+
img.save(buffered, format="PNG")
|
18 |
+
return base64.b64encode(buffered.getvalue()).decode()
|
19 |
+
|
20 |
+
|
21 |
+
def resize_img(img: Image.Image, max_len: int, min_len: int) -> Image.Image:
|
22 |
+
max_hw, min_hw = max(img.size), min(img.size)
|
23 |
+
aspect_ratio = max_hw / min_hw
|
24 |
+
# max_len, min_len = 800, 400
|
25 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
26 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
27 |
+
W, H = img.size
|
28 |
+
if H > W:
|
29 |
+
H, W = longest_edge, shortest_edge
|
30 |
+
else:
|
31 |
+
H, W = shortest_edge, longest_edge
|
32 |
+
return img.resize((W, H))
|
33 |
+
|
34 |
+
|
35 |
+
@dataclasses.dataclass
|
36 |
+
class Conversation:
|
37 |
+
"""A class that keeps all conversation history."""
|
38 |
+
|
39 |
+
SYSTEM = "system"
|
40 |
+
USER = "user"
|
41 |
+
ASSISTANT = "assistant"
|
42 |
+
|
43 |
+
roles: List[str] = field(
|
44 |
+
default_factory=lambda: [
|
45 |
+
Conversation.SYSTEM,
|
46 |
+
Conversation.USER,
|
47 |
+
Conversation.ASSISTANT,
|
48 |
+
]
|
49 |
+
)
|
50 |
+
mandatory_system_message = "我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
|
51 |
+
system_message: str = "请尽可能详细地回答用户的问题。"
|
52 |
+
messages: List[Dict[str, Any]] = field(default_factory=lambda: [])
|
53 |
+
max_image_limit: int = 4
|
54 |
+
skip_next: bool = False
|
55 |
+
streaming_placeholder: str = "▌"
|
56 |
+
|
57 |
+
def get_system_message(self):
|
58 |
+
return self.mandatory_system_message + "\n\n" + self.system_message
|
59 |
+
|
60 |
+
def set_system_message(self, system_message: str):
|
61 |
+
self.system_message = system_message
|
62 |
+
return self
|
63 |
+
|
64 |
+
def get_prompt(self, inlude_image=False):
|
65 |
+
send_messages = [{"role": "system", "content": self.get_system_message()}]
|
66 |
+
# send_messages = []
|
67 |
+
for message in self.messages:
|
68 |
+
if message["role"] == self.USER:
|
69 |
+
user_message = {
|
70 |
+
"role": self.USER,
|
71 |
+
"content": message["content"],
|
72 |
+
}
|
73 |
+
if inlude_image and "image" in message:
|
74 |
+
user_message["image"] = []
|
75 |
+
for image in message["image"]:
|
76 |
+
user_message["image"].append(pil2base64(image))
|
77 |
+
send_messages.append(user_message)
|
78 |
+
elif message["role"] == self.ASSISTANT:
|
79 |
+
send_messages.append(
|
80 |
+
{"role": self.ASSISTANT, "content": message["content"]}
|
81 |
+
)
|
82 |
+
elif message["role"] == self.SYSTEM:
|
83 |
+
send_messages.append(
|
84 |
+
{
|
85 |
+
"role": self.SYSTEM,
|
86 |
+
"content": message["content"],
|
87 |
+
}
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
raise ValueError(f"Invalid role: {message['role']}")
|
91 |
+
return send_messages
|
92 |
+
|
93 |
+
def append_message(
|
94 |
+
self,
|
95 |
+
role,
|
96 |
+
content,
|
97 |
+
image_list=None,
|
98 |
+
):
|
99 |
+
self.messages.append(
|
100 |
+
{
|
101 |
+
"role": role,
|
102 |
+
"content": content,
|
103 |
+
"image": [] if image_list is None else image_list,
|
104 |
+
# "filenames": save_filenames,
|
105 |
+
}
|
106 |
+
)
|
107 |
+
|
108 |
+
def get_images(
|
109 |
+
self,
|
110 |
+
return_copy=False,
|
111 |
+
return_base64=False,
|
112 |
+
source: Union[str, None] = None,
|
113 |
+
):
|
114 |
+
assert source in [self.USER, self.ASSISTANT, None], f"Invalid source: {soure}"
|
115 |
+
images = []
|
116 |
+
for i, msg in enumerate(self.messages):
|
117 |
+
if source and msg["role"] != source:
|
118 |
+
continue
|
119 |
+
|
120 |
+
for image in msg.get("image", []):
|
121 |
+
# org_image = [i.copy() for i in image]
|
122 |
+
if return_copy:
|
123 |
+
image = image.copy()
|
124 |
+
|
125 |
+
if return_base64:
|
126 |
+
image = pil2base64(image)
|
127 |
+
|
128 |
+
images.append(image)
|
129 |
+
|
130 |
+
return images
|
131 |
+
|
132 |
+
def to_gradio_chatbot(self):
|
133 |
+
ret = []
|
134 |
+
for i, msg in enumerate(self.messages):
|
135 |
+
if msg["role"] == self.SYSTEM:
|
136 |
+
continue
|
137 |
+
|
138 |
+
alt_str = (
|
139 |
+
"user upload image" if msg["role"] == self.USER else "output image"
|
140 |
+
)
|
141 |
+
image = msg.get("image", [])
|
142 |
+
if not isinstance(image, list):
|
143 |
+
images = [image]
|
144 |
+
else:
|
145 |
+
images = image
|
146 |
+
|
147 |
+
img_str_list = []
|
148 |
+
for i in range(len(images)):
|
149 |
+
image = resize_img(
|
150 |
+
images[i],
|
151 |
+
400,
|
152 |
+
800,
|
153 |
+
)
|
154 |
+
img_b64_str = pil2base64(image)
|
155 |
+
W, H = image.size
|
156 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="{alt_str}" style="width: {W}px; max-width:none; max-height:none"></img>'
|
157 |
+
img_str = (
|
158 |
+
f'<img src="data:image/png;base64,{img_b64_str}" alt="{alt_str}" />'
|
159 |
+
)
|
160 |
+
img_str_list.append(img_str)
|
161 |
+
|
162 |
+
if msg["role"] == self.USER:
|
163 |
+
msg_str = " ".join(img_str_list) + msg["content"]
|
164 |
+
ret.append([msg_str, None])
|
165 |
+
else:
|
166 |
+
msg_str = msg["content"] + " ".join(img_str_list)
|
167 |
+
ret[-1][-1] = msg_str
|
168 |
+
return ret
|
169 |
+
|
170 |
+
def update_message(self, role, content, image=None, idx=-1):
|
171 |
+
assert len(self.messages) > 0, "No message in the conversation."
|
172 |
+
|
173 |
+
idx = (idx + len(self.messages)) % len(self.messages)
|
174 |
+
|
175 |
+
assert (
|
176 |
+
self.messages[idx]["role"] == role
|
177 |
+
), f"Role mismatch: {role} vs {self.messages[idx]['role']}"
|
178 |
+
|
179 |
+
self.messages[idx]["content"] = content
|
180 |
+
if image is not None:
|
181 |
+
if image not in self.messages[idx]["image"]:
|
182 |
+
self.messages[idx]["image"] = []
|
183 |
+
if not isinstance(image, list):
|
184 |
+
image = [image]
|
185 |
+
self.messages[idx]["image"].extend(image)
|
186 |
+
|
187 |
+
def return_last_message(self):
|
188 |
+
return self.messages[-1]["content"]
|
189 |
+
|
190 |
+
def end_of_current_turn(self):
|
191 |
+
assert len(self.messages) > 0, "No message in the conversation."
|
192 |
+
assert (
|
193 |
+
self.messages[-1]["role"] == self.ASSISTANT
|
194 |
+
), f"It should end with the message from assistant instead of {self.messages[-1]['role']}."
|
195 |
+
|
196 |
+
if self.messages[-1]["content"][-1] != self.streaming_placeholder:
|
197 |
+
return
|
198 |
+
|
199 |
+
self.update_message(self.ASSISTANT, self.messages[-1]["content"][:-1], None)
|
200 |
+
|
201 |
+
def copy(self):
|
202 |
+
return Conversation(
|
203 |
+
mandatory_system_message=self.mandatory_system_message,
|
204 |
+
system_message=self.system_message,
|
205 |
+
roles=copy.deepcopy(self.roles),
|
206 |
+
messages=copy.deepcopy(self.messages),
|
207 |
+
)
|
208 |
+
|
209 |
+
def dict(self):
|
210 |
+
"""
|
211 |
+
all_images = state.get_images()
|
212 |
+
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
|
213 |
+
t = datetime.datetime.now()
|
214 |
+
for image, hash in zip(all_images, all_image_hash):
|
215 |
+
filename = os.path.join(
|
216 |
+
LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg"
|
217 |
+
)
|
218 |
+
if not os.path.isfile(filename):
|
219 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
220 |
+
image.save(filename)
|
221 |
+
"""
|
222 |
+
messages = []
|
223 |
+
for message in self.messages:
|
224 |
+
images = []
|
225 |
+
for image in message.get("image", []):
|
226 |
+
filename = self.save_image(image)
|
227 |
+
images.append(filename)
|
228 |
+
|
229 |
+
messages.append(
|
230 |
+
{
|
231 |
+
"role": message["role"],
|
232 |
+
"content": message["content"],
|
233 |
+
"image": images,
|
234 |
+
}
|
235 |
+
)
|
236 |
+
if len(images) == 0:
|
237 |
+
messages[-1].pop("image")
|
238 |
+
|
239 |
+
return {
|
240 |
+
"mandatory_system_message": self.mandatory_system_message,
|
241 |
+
"system_message": self.system_message,
|
242 |
+
"roles": self.roles,
|
243 |
+
"messages": messages,
|
244 |
+
}
|
245 |
+
|
246 |
+
def save_image(self, image: Image.Image) -> str:
|
247 |
+
t = datetime.datetime.now()
|
248 |
+
image_hash = hashlib.md5(image.tobytes()).hexdigest()
|
249 |
+
filename = os.path.join(
|
250 |
+
LOGDIR,
|
251 |
+
"serve_images",
|
252 |
+
f"{t.year}-{t.month:02d}-{t.day:02d}",
|
253 |
+
f"{image_hash}.jpg",
|
254 |
+
)
|
255 |
+
if not os.path.isfile(filename):
|
256 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
257 |
+
image.save(filename)
|
258 |
+
|
259 |
+
return filename
|
gallery/child_1.jpg
ADDED
gallery/child_2.jpg
ADDED
gallery/child_3.jpg
ADDED
gradio_web_server.py
ADDED
@@ -0,0 +1,824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from ast import parse
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
import hashlib
|
8 |
+
import re
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
import requests
|
12 |
+
import random
|
13 |
+
from filelock import FileLock
|
14 |
+
from io import BytesIO
|
15 |
+
from PIL import Image, ImageDraw, ImageFont
|
16 |
+
|
17 |
+
from constants import LOGDIR
|
18 |
+
from utils import (
|
19 |
+
build_logger,
|
20 |
+
server_error_msg,
|
21 |
+
violates_moderation,
|
22 |
+
moderation_msg,
|
23 |
+
load_image_from_base64,
|
24 |
+
get_log_filename,
|
25 |
+
)
|
26 |
+
from conversation import Conversation
|
27 |
+
|
28 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
29 |
+
|
30 |
+
headers = {"User-Agent": "InternVL-Chat Client"}
|
31 |
+
|
32 |
+
no_change_btn = gr.Button()
|
33 |
+
enable_btn = gr.Button(interactive=True)
|
34 |
+
disable_btn = gr.Button(interactive=False)
|
35 |
+
|
36 |
+
|
37 |
+
def write2file(path, content):
|
38 |
+
lock = FileLock(f"{path}.lock")
|
39 |
+
with lock:
|
40 |
+
with open(path, "a") as fout:
|
41 |
+
fout.write(content)
|
42 |
+
|
43 |
+
|
44 |
+
def sort_models(models):
|
45 |
+
def custom_sort_key(model_name):
|
46 |
+
# InternVL-Chat-V1-5 should be the first item
|
47 |
+
if model_name == "InternVL-Chat-V1-5":
|
48 |
+
return (1, model_name) # 1 indicates highest precedence
|
49 |
+
elif model_name.startswith("InternVL-Chat-V1-5-"):
|
50 |
+
return (1, model_name) # 1 indicates highest precedence
|
51 |
+
else:
|
52 |
+
return (0, model_name) # 0 indicates normal order
|
53 |
+
|
54 |
+
models.sort(key=custom_sort_key, reverse=True)
|
55 |
+
try: # We have five InternVL-Chat-V1-5 models, randomly choose one to be the first
|
56 |
+
first_three = models[:4]
|
57 |
+
random.shuffle(first_three)
|
58 |
+
models[:4] = first_three
|
59 |
+
except:
|
60 |
+
pass
|
61 |
+
return models
|
62 |
+
|
63 |
+
|
64 |
+
def get_model_list():
|
65 |
+
ret = requests.post(args.controller_url + "/refresh_all_workers")
|
66 |
+
assert ret.status_code == 200
|
67 |
+
ret = requests.post(args.controller_url + "/list_models")
|
68 |
+
models = ret.json()["models"]
|
69 |
+
models = sort_models(models)
|
70 |
+
|
71 |
+
logger.info(f"Models: {models}")
|
72 |
+
return models
|
73 |
+
|
74 |
+
|
75 |
+
get_window_url_params = """
|
76 |
+
function() {
|
77 |
+
const params = new URLSearchParams(window.location.search);
|
78 |
+
url_params = Object.fromEntries(params);
|
79 |
+
console.log(url_params);
|
80 |
+
return url_params;
|
81 |
+
}
|
82 |
+
"""
|
83 |
+
|
84 |
+
|
85 |
+
def init_state(state=None):
|
86 |
+
if state is not None:
|
87 |
+
del state
|
88 |
+
return Conversation()
|
89 |
+
|
90 |
+
|
91 |
+
def find_bounding_boxes(state, response):
|
92 |
+
pattern = re.compile(r"<ref>\s*(.*?)\s*</ref>\s*<box>\s*(\[\[.*?\]\])\s*</box>")
|
93 |
+
matches = pattern.findall(response)
|
94 |
+
results = []
|
95 |
+
for match in matches:
|
96 |
+
results.append((match[0], eval(match[1])))
|
97 |
+
returned_image = None
|
98 |
+
latest_image = state.get_images(source=state.USER)[-1]
|
99 |
+
returned_image = latest_image.copy()
|
100 |
+
width, height = returned_image.size
|
101 |
+
draw = ImageDraw.Draw(returned_image)
|
102 |
+
for result in results:
|
103 |
+
line_width = max(1, int(min(width, height) / 200))
|
104 |
+
random_color = (
|
105 |
+
random.randint(0, 128),
|
106 |
+
random.randint(0, 128),
|
107 |
+
random.randint(0, 128),
|
108 |
+
)
|
109 |
+
category_name, coordinates = result
|
110 |
+
coordinates = [
|
111 |
+
(
|
112 |
+
float(x[0]) / 1000,
|
113 |
+
float(x[1]) / 1000,
|
114 |
+
float(x[2]) / 1000,
|
115 |
+
float(x[3]) / 1000,
|
116 |
+
)
|
117 |
+
for x in coordinates
|
118 |
+
]
|
119 |
+
coordinates = [
|
120 |
+
(
|
121 |
+
int(x[0] * width),
|
122 |
+
int(x[1] * height),
|
123 |
+
int(x[2] * width),
|
124 |
+
int(x[3] * height),
|
125 |
+
)
|
126 |
+
for x in coordinates
|
127 |
+
]
|
128 |
+
for box in coordinates:
|
129 |
+
draw.rectangle(box, outline=random_color, width=line_width)
|
130 |
+
font = ImageFont.truetype("assets/SimHei.ttf", int(20 * line_width / 2))
|
131 |
+
text_size = font.getbbox(category_name)
|
132 |
+
text_width, text_height = (
|
133 |
+
text_size[2] - text_size[0],
|
134 |
+
text_size[3] - text_size[1],
|
135 |
+
)
|
136 |
+
text_position = (box[0], max(0, box[1] - text_height))
|
137 |
+
draw.rectangle(
|
138 |
+
[
|
139 |
+
text_position,
|
140 |
+
(text_position[0] + text_width, text_position[1] + text_height),
|
141 |
+
],
|
142 |
+
fill=random_color,
|
143 |
+
)
|
144 |
+
draw.text(text_position, category_name, fill="white", font=font)
|
145 |
+
return returned_image if len(matches) > 0 else None
|
146 |
+
|
147 |
+
|
148 |
+
def query_image_generation(response, sd_worker_url, timeout=15):
|
149 |
+
if not sd_worker_url:
|
150 |
+
return None
|
151 |
+
sd_worker_url = f"{sd_worker_url}/generate_image/"
|
152 |
+
pattern = r"```drawing-instruction\n(.*?)\n```"
|
153 |
+
match = re.search(pattern, response, re.DOTALL)
|
154 |
+
if match:
|
155 |
+
payload = {"caption": match.group(1)}
|
156 |
+
print("drawing-instruction:", payload)
|
157 |
+
response = requests.post(sd_worker_url, json=payload, timeout=timeout)
|
158 |
+
response.raise_for_status() # 检查HTTP请求是否成功
|
159 |
+
image = Image.open(BytesIO(response.content))
|
160 |
+
return image
|
161 |
+
else:
|
162 |
+
return None
|
163 |
+
|
164 |
+
|
165 |
+
def load_demo(url_params, request: gr.Request):
|
166 |
+
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
167 |
+
|
168 |
+
dropdown_update = gr.Dropdown(visible=True)
|
169 |
+
if "model" in url_params:
|
170 |
+
model = url_params["model"]
|
171 |
+
if model in models:
|
172 |
+
dropdown_update = gr.Dropdown(value=model, visible=True)
|
173 |
+
|
174 |
+
state = init_state()
|
175 |
+
return state, dropdown_update
|
176 |
+
|
177 |
+
|
178 |
+
def load_demo_refresh_model_list(request: gr.Request):
|
179 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
180 |
+
models = get_model_list()
|
181 |
+
state = init_state()
|
182 |
+
dropdown_update = gr.Dropdown(
|
183 |
+
choices=models, value=models[0] if len(models) > 0 else ""
|
184 |
+
)
|
185 |
+
return state, dropdown_update
|
186 |
+
|
187 |
+
|
188 |
+
def vote_last_response(state, liked, model_selector, request: gr.Request):
|
189 |
+
conv_data = {
|
190 |
+
"tstamp": round(time.time(), 4),
|
191 |
+
"like": liked,
|
192 |
+
"model": model_selector,
|
193 |
+
"state": state.dict(),
|
194 |
+
"ip": request.client.host,
|
195 |
+
}
|
196 |
+
write2file(get_log_filename(), json.dumps(conv_data) + "\n")
|
197 |
+
|
198 |
+
|
199 |
+
def upvote_last_response(state, model_selector, request: gr.Request):
|
200 |
+
logger.info(f"upvote. ip: {request.client.host}")
|
201 |
+
vote_last_response(state, True, model_selector, request)
|
202 |
+
textbox = gr.MultimodalTextbox(value=None, interactive=True)
|
203 |
+
return (textbox,) + (disable_btn,) * 3
|
204 |
+
|
205 |
+
|
206 |
+
def downvote_last_response(state, model_selector, request: gr.Request):
|
207 |
+
logger.info(f"downvote. ip: {request.client.host}")
|
208 |
+
vote_last_response(state, False, model_selector, request)
|
209 |
+
textbox = gr.MultimodalTextbox(value=None, interactive=True)
|
210 |
+
return (textbox,) + (disable_btn,) * 3
|
211 |
+
|
212 |
+
|
213 |
+
def vote_selected_response(
|
214 |
+
state, model_selector, request: gr.Request, data: gr.LikeData
|
215 |
+
):
|
216 |
+
logger.info(
|
217 |
+
f"Vote: {data.liked}, index: {data.index}, value: {data.value} , ip: {request.client.host}"
|
218 |
+
)
|
219 |
+
conv_data = {
|
220 |
+
"tstamp": round(time.time(), 4),
|
221 |
+
"like": data.liked,
|
222 |
+
"index": data.index,
|
223 |
+
"model": model_selector,
|
224 |
+
"state": state.dict(),
|
225 |
+
"ip": request.client.host,
|
226 |
+
}
|
227 |
+
write2file(get_log_filename(), json.dumps(conv_data) + "\n")
|
228 |
+
return
|
229 |
+
|
230 |
+
|
231 |
+
def flag_last_response(state, model_selector, request: gr.Request):
|
232 |
+
logger.info(f"flag. ip: {request.client.host}")
|
233 |
+
vote_last_response(state, "flag", model_selector, request)
|
234 |
+
textbox = gr.MultimodalTextbox(value=None, interactive=True)
|
235 |
+
return (textbox,) + (disable_btn,) * 3
|
236 |
+
|
237 |
+
|
238 |
+
def regenerate(state, image_process_mode, request: gr.Request):
|
239 |
+
logger.info(f"regenerate. ip: {request.client.host}")
|
240 |
+
# state.messages[-1][-1] = None
|
241 |
+
state.update_message(Conversation.ASSISTANT, None, -1)
|
242 |
+
prev_human_msg = state.messages[-2]
|
243 |
+
if type(prev_human_msg[1]) in (tuple, list):
|
244 |
+
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
|
245 |
+
state.skip_next = False
|
246 |
+
textbox = gr.MultimodalTextbox(value=None, interactive=True)
|
247 |
+
return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
|
248 |
+
|
249 |
+
|
250 |
+
def clear_history(request: gr.Request):
|
251 |
+
logger.info(f"clear_history. ip: {request.client.host}")
|
252 |
+
state = init_state()
|
253 |
+
textbox = gr.MultimodalTextbox(value=None, interactive=True)
|
254 |
+
return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
|
255 |
+
|
256 |
+
|
257 |
+
def change_system_prompt(state, system_prompt, request: gr.Request):
|
258 |
+
logger.info(f"Change system prompt. ip: {request.client.host}")
|
259 |
+
state.set_system_message(system_prompt)
|
260 |
+
return state
|
261 |
+
|
262 |
+
|
263 |
+
def add_text(state, message, system_prompt, request: gr.Request):
|
264 |
+
images = message.get("files", [])
|
265 |
+
text = message.get("text", "").strip()
|
266 |
+
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
|
267 |
+
# import pdb; pdb.set_trace()
|
268 |
+
textbox = gr.MultimodalTextbox(value=None, interactive=False)
|
269 |
+
if len(text) <= 0 and len(images) == 0:
|
270 |
+
state.skip_next = True
|
271 |
+
return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
|
272 |
+
if args.moderate:
|
273 |
+
flagged = violates_moderation(text)
|
274 |
+
if flagged:
|
275 |
+
state.skip_next = True
|
276 |
+
textbox = gr.MultimodalTextbox(
|
277 |
+
value={"text": moderation_msg}, interactive=True
|
278 |
+
)
|
279 |
+
return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
|
280 |
+
images = [Image.open(path).convert("RGB") for path in images]
|
281 |
+
|
282 |
+
if len(images) > 0 and len(state.get_images(source=state.USER)) > 0:
|
283 |
+
state = init_state(state)
|
284 |
+
state.set_system_message(system_prompt)
|
285 |
+
state.append_message(Conversation.USER, text, images)
|
286 |
+
state.skip_next = False
|
287 |
+
return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
|
288 |
+
|
289 |
+
|
290 |
+
def http_bot(
|
291 |
+
state,
|
292 |
+
model_selector,
|
293 |
+
temperature,
|
294 |
+
top_p,
|
295 |
+
repetition_penalty,
|
296 |
+
max_new_tokens,
|
297 |
+
max_input_tiles,
|
298 |
+
# bbox_threshold,
|
299 |
+
# mask_threshold,
|
300 |
+
request: gr.Request,
|
301 |
+
):
|
302 |
+
logger.info(f"http_bot. ip: {request.client.host}")
|
303 |
+
start_tstamp = time.time()
|
304 |
+
model_name = model_selector
|
305 |
+
if hasattr(state, "skip_next") and state.skip_next:
|
306 |
+
# This generate call is skipped due to invalid inputs
|
307 |
+
yield (
|
308 |
+
state,
|
309 |
+
state.to_gradio_chatbot(),
|
310 |
+
gr.MultimodalTextbox(interactive=False),
|
311 |
+
) + (no_change_btn,) * 5
|
312 |
+
return
|
313 |
+
|
314 |
+
# Query worker address
|
315 |
+
controller_url = args.controller_url
|
316 |
+
ret = requests.post(
|
317 |
+
controller_url + "/get_worker_address", json={"model": model_name}
|
318 |
+
)
|
319 |
+
worker_addr = ret.json()["address"]
|
320 |
+
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
321 |
+
|
322 |
+
# No available worker
|
323 |
+
if worker_addr == "":
|
324 |
+
# state.messages[-1][-1] = server_error_msg
|
325 |
+
state.update_message(Conversation.ASSISTANT, server_error_msg)
|
326 |
+
yield (
|
327 |
+
state,
|
328 |
+
state.to_gradio_chatbot(),
|
329 |
+
gr.MultimodalTextbox(interactive=False),
|
330 |
+
disable_btn,
|
331 |
+
disable_btn,
|
332 |
+
disable_btn,
|
333 |
+
enable_btn,
|
334 |
+
enable_btn,
|
335 |
+
)
|
336 |
+
return
|
337 |
+
|
338 |
+
all_images = state.get_images(source=state.USER)
|
339 |
+
all_image_paths = [state.save_image(image) for image in all_images]
|
340 |
+
|
341 |
+
# Make requests
|
342 |
+
pload = {
|
343 |
+
"model": model_name,
|
344 |
+
"prompt": state.get_prompt(),
|
345 |
+
"temperature": float(temperature),
|
346 |
+
"top_p": float(top_p),
|
347 |
+
"max_new_tokens": max_new_tokens,
|
348 |
+
"max_input_tiles": max_input_tiles,
|
349 |
+
# "bbox_threshold": bbox_threshold,
|
350 |
+
# "mask_threshold": mask_threshold,
|
351 |
+
"repetition_penalty": repetition_penalty,
|
352 |
+
"images": f"List of {len(all_images)} images: {all_image_paths}",
|
353 |
+
}
|
354 |
+
logger.info(f"==== request ====\n{pload}")
|
355 |
+
pload.pop("images")
|
356 |
+
pload["prompt"] = state.get_prompt(inlude_image=True)
|
357 |
+
state.append_message(Conversation.ASSISTANT, state.streaming_placeholder)
|
358 |
+
yield (
|
359 |
+
state,
|
360 |
+
state.to_gradio_chatbot(),
|
361 |
+
gr.MultimodalTextbox(interactive=False),
|
362 |
+
) + (disable_btn,) * 5
|
363 |
+
|
364 |
+
try:
|
365 |
+
# Stream output
|
366 |
+
response = requests.post(
|
367 |
+
worker_addr + "/worker_generate_stream",
|
368 |
+
headers=headers,
|
369 |
+
json=pload,
|
370 |
+
stream=True,
|
371 |
+
timeout=20,
|
372 |
+
)
|
373 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
374 |
+
if chunk:
|
375 |
+
data = json.loads(chunk.decode())
|
376 |
+
if data["error_code"] == 0:
|
377 |
+
if "text" in data:
|
378 |
+
output = data["text"].strip()
|
379 |
+
output += state.streaming_placeholder
|
380 |
+
|
381 |
+
image = None
|
382 |
+
if "image" in data:
|
383 |
+
image = load_image_from_base64(data["image"])
|
384 |
+
_ = state.save_image(image)
|
385 |
+
|
386 |
+
state.update_message(Conversation.ASSISTANT, output, image)
|
387 |
+
yield (
|
388 |
+
state,
|
389 |
+
state.to_gradio_chatbot(),
|
390 |
+
gr.MultimodalTextbox(interactive=False),
|
391 |
+
) + (disable_btn,) * 5
|
392 |
+
else:
|
393 |
+
output = (
|
394 |
+
f"**{data['text']}**" + f" (error_code: {data['error_code']})"
|
395 |
+
)
|
396 |
+
|
397 |
+
state.update_message(Conversation.ASSISTANT, output, None)
|
398 |
+
yield (
|
399 |
+
state,
|
400 |
+
state.to_gradio_chatbot(),
|
401 |
+
gr.MultimodalTextbox(interactive=True),
|
402 |
+
) + (
|
403 |
+
disable_btn,
|
404 |
+
disable_btn,
|
405 |
+
disable_btn,
|
406 |
+
enable_btn,
|
407 |
+
enable_btn,
|
408 |
+
)
|
409 |
+
return
|
410 |
+
except requests.exceptions.RequestException as e:
|
411 |
+
state.update_message(Conversation.ASSISTANT, server_error_msg, None)
|
412 |
+
yield (
|
413 |
+
state,
|
414 |
+
state.to_gradio_chatbot(),
|
415 |
+
gr.MultimodalTextbox(interactive=True),
|
416 |
+
) + (
|
417 |
+
disable_btn,
|
418 |
+
disable_btn,
|
419 |
+
disable_btn,
|
420 |
+
enable_btn,
|
421 |
+
enable_btn,
|
422 |
+
)
|
423 |
+
return
|
424 |
+
|
425 |
+
ai_response = state.return_last_message()
|
426 |
+
if "<ref>" in ai_response:
|
427 |
+
returned_image = find_bounding_boxes(state, ai_response)
|
428 |
+
returned_image = [returned_image] if returned_image else []
|
429 |
+
state.update_message(Conversation.ASSISTANT, ai_response, returned_image)
|
430 |
+
if "```drawing-instruction" in ai_response:
|
431 |
+
returned_image = query_image_generation(
|
432 |
+
ai_response, sd_worker_url=sd_worker_url
|
433 |
+
)
|
434 |
+
returned_image = [returned_image] if returned_image else []
|
435 |
+
state.update_message(Conversation.ASSISTANT, ai_response, returned_image)
|
436 |
+
|
437 |
+
state.end_of_current_turn()
|
438 |
+
|
439 |
+
yield (
|
440 |
+
state,
|
441 |
+
state.to_gradio_chatbot(),
|
442 |
+
gr.MultimodalTextbox(interactive=True),
|
443 |
+
) + (enable_btn,) * 5
|
444 |
+
|
445 |
+
finish_tstamp = time.time()
|
446 |
+
logger.info(f"{output}")
|
447 |
+
data = {
|
448 |
+
"tstamp": round(finish_tstamp, 4),
|
449 |
+
"like": None,
|
450 |
+
"model": model_name,
|
451 |
+
"start": round(start_tstamp, 4),
|
452 |
+
"finish": round(start_tstamp, 4),
|
453 |
+
"state": state.dict(),
|
454 |
+
"images": all_image_paths,
|
455 |
+
"ip": request.client.host,
|
456 |
+
}
|
457 |
+
write2file(get_log_filename(), json.dumps(data) + "\n")
|
458 |
+
|
459 |
+
|
460 |
+
title_html = """
|
461 |
+
<h2> <span class="gradient-text" id="text">InternVL2</span><span class="plain-text">: Better than the Best—Expanding Performance Boundaries of Open-Source Multimodal Models with the Progressive Scaling Strategy</span></h2>
|
462 |
+
<a href="https://internvl.github.io/blog/2024-07-02-InternVL-2.0/">[📜 InternVL2 Blog]</a>
|
463 |
+
<a href="https://huggingface.co/spaces/OpenGVLab/InternVL">[🤗 HF Demo]</a>
|
464 |
+
<a href="https://github.com/OpenGVLab/InternVL?tab=readme-ov-file#quick-start-with-huggingface">[🚀 Quick Start]</a>
|
465 |
+
<a href="https://github.com/OpenGVLab/InternVL/blob/main/document/How_to_use_InternVL_API.md">[🌐 API]</a>
|
466 |
+
"""
|
467 |
+
|
468 |
+
tos_markdown = """
|
469 |
+
### Terms of use
|
470 |
+
By using this service, users are required to agree to the following terms:
|
471 |
+
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
472 |
+
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
473 |
+
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
474 |
+
"""
|
475 |
+
|
476 |
+
|
477 |
+
learn_more_markdown = """
|
478 |
+
### License
|
479 |
+
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
|
480 |
+
|
481 |
+
### Acknowledgement
|
482 |
+
This demo is modified from LLaVA's demo. Thanks for their awesome work!
|
483 |
+
"""
|
484 |
+
# .gradio-container {margin: 5px 10px 0 10px !important};
|
485 |
+
block_css = """
|
486 |
+
.gradio-container {margin: 0.1% 1% 0 1% !important; max-width: 98% !important;};
|
487 |
+
#buttons button {
|
488 |
+
min-width: min(120px,100%);
|
489 |
+
}
|
490 |
+
|
491 |
+
.gradient-text {
|
492 |
+
font-size: 28px;
|
493 |
+
width: auto;
|
494 |
+
font-weight: bold;
|
495 |
+
background: linear-gradient(45deg, red, orange, yellow, green, blue, indigo, violet);
|
496 |
+
background-clip: text;
|
497 |
+
-webkit-background-clip: text;
|
498 |
+
color: transparent;
|
499 |
+
}
|
500 |
+
|
501 |
+
.plain-text {
|
502 |
+
font-size: 22px;
|
503 |
+
width: auto;
|
504 |
+
font-weight: bold;
|
505 |
+
}
|
506 |
+
"""
|
507 |
+
|
508 |
+
js = """
|
509 |
+
function createWaveAnimation() {
|
510 |
+
const text = document.getElementById('text');
|
511 |
+
var i = 0;
|
512 |
+
setInterval(function() {
|
513 |
+
const colors = [
|
514 |
+
'red, orange, yellow, green, blue, indigo, violet, purple',
|
515 |
+
'orange, yellow, green, blue, indigo, violet, purple, red',
|
516 |
+
'yellow, green, blue, indigo, violet, purple, red, orange',
|
517 |
+
'green, blue, indigo, violet, purple, red, orange, yellow',
|
518 |
+
'blue, indigo, violet, purple, red, orange, yellow, green',
|
519 |
+
'indigo, violet, purple, red, orange, yellow, green, blue',
|
520 |
+
'violet, purple, red, orange, yellow, green, blue, indigo',
|
521 |
+
'purple, red, orange, yellow, green, blue, indigo, violet',
|
522 |
+
];
|
523 |
+
const angle = 45;
|
524 |
+
const colorIndex = i % colors.length;
|
525 |
+
text.style.background = `linear-gradient(${angle}deg, ${colors[colorIndex]})`;
|
526 |
+
text.style.webkitBackgroundClip = 'text';
|
527 |
+
text.style.backgroundClip = 'text';
|
528 |
+
text.style.color = 'transparent';
|
529 |
+
text.style.fontSize = '28px';
|
530 |
+
text.style.width = 'auto';
|
531 |
+
text.textContent = 'InternVL2';
|
532 |
+
text.style.fontWeight = 'bold';
|
533 |
+
i += 1;
|
534 |
+
}, 200);
|
535 |
+
const params = new URLSearchParams(window.location.search);
|
536 |
+
url_params = Object.fromEntries(params);
|
537 |
+
console.log(url_params);
|
538 |
+
return url_params;
|
539 |
+
}
|
540 |
+
|
541 |
+
"""
|
542 |
+
|
543 |
+
|
544 |
+
def build_demo(embed_mode):
|
545 |
+
textbox = gr.MultimodalTextbox(
|
546 |
+
interactive=True,
|
547 |
+
file_types=["image", "video"],
|
548 |
+
placeholder="Enter message or upload file...",
|
549 |
+
show_label=False,
|
550 |
+
)
|
551 |
+
|
552 |
+
with gr.Blocks(
|
553 |
+
title="InternVL-Chat",
|
554 |
+
theme=gr.themes.Default(),
|
555 |
+
css=block_css,
|
556 |
+
) as demo:
|
557 |
+
state = gr.State()
|
558 |
+
|
559 |
+
if not embed_mode:
|
560 |
+
# gr.Markdown(title_markdown)
|
561 |
+
gr.HTML(title_html)
|
562 |
+
|
563 |
+
with gr.Row():
|
564 |
+
with gr.Column(scale=2):
|
565 |
+
|
566 |
+
with gr.Row(elem_id="model_selector_row"):
|
567 |
+
model_selector = gr.Dropdown(
|
568 |
+
choices=models,
|
569 |
+
value=models[0] if len(models) > 0 else "",
|
570 |
+
# value="InternVL-Chat-V1-5",
|
571 |
+
interactive=True,
|
572 |
+
show_label=False,
|
573 |
+
container=False,
|
574 |
+
)
|
575 |
+
|
576 |
+
with gr.Accordion("System Prompt", open=False) as system_prompt_row:
|
577 |
+
system_prompt = gr.Textbox(
|
578 |
+
value="请尽可能详细地回答用户的问题。",
|
579 |
+
label="System Prompt",
|
580 |
+
interactive=True,
|
581 |
+
)
|
582 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
583 |
+
temperature = gr.Slider(
|
584 |
+
minimum=0.0,
|
585 |
+
maximum=1.0,
|
586 |
+
value=0.2,
|
587 |
+
step=0.1,
|
588 |
+
interactive=True,
|
589 |
+
label="Temperature",
|
590 |
+
)
|
591 |
+
top_p = gr.Slider(
|
592 |
+
minimum=0.0,
|
593 |
+
maximum=1.0,
|
594 |
+
value=0.7,
|
595 |
+
step=0.1,
|
596 |
+
interactive=True,
|
597 |
+
label="Top P",
|
598 |
+
)
|
599 |
+
repetition_penalty = gr.Slider(
|
600 |
+
minimum=1.0,
|
601 |
+
maximum=1.5,
|
602 |
+
value=1.1,
|
603 |
+
step=0.02,
|
604 |
+
interactive=True,
|
605 |
+
label="Repetition penalty",
|
606 |
+
)
|
607 |
+
max_output_tokens = gr.Slider(
|
608 |
+
minimum=0,
|
609 |
+
maximum=4096,
|
610 |
+
value=1024,
|
611 |
+
step=64,
|
612 |
+
interactive=True,
|
613 |
+
label="Max output tokens",
|
614 |
+
)
|
615 |
+
max_input_tiles = gr.Slider(
|
616 |
+
minimum=1,
|
617 |
+
maximum=32,
|
618 |
+
value=12,
|
619 |
+
step=1,
|
620 |
+
interactive=True,
|
621 |
+
label="Max input tiles (control the image size)",
|
622 |
+
)
|
623 |
+
examples = gr.Examples(
|
624 |
+
examples=[
|
625 |
+
[
|
626 |
+
{
|
627 |
+
"files": [
|
628 |
+
"gallery/prod_9.jpg",
|
629 |
+
],
|
630 |
+
"text": "What's at the far end of the image?",
|
631 |
+
}
|
632 |
+
],
|
633 |
+
[
|
634 |
+
{
|
635 |
+
"files": [
|
636 |
+
"gallery/astro_on_unicorn.png",
|
637 |
+
],
|
638 |
+
"text": "What does this image mean?",
|
639 |
+
}
|
640 |
+
],
|
641 |
+
[
|
642 |
+
{
|
643 |
+
"files": [
|
644 |
+
"gallery/prod_12.png",
|
645 |
+
],
|
646 |
+
"text": "What are the consequences of the easy decisions shown in this image?",
|
647 |
+
}
|
648 |
+
],
|
649 |
+
[
|
650 |
+
{
|
651 |
+
"files": [
|
652 |
+
"gallery/child_1.jpg",
|
653 |
+
"gallery/child_2.jpg",
|
654 |
+
f"gallery/child_3.jpg",
|
655 |
+
],
|
656 |
+
"text": "这三帧图片讲述了一件什么事情?",
|
657 |
+
}
|
658 |
+
],
|
659 |
+
],
|
660 |
+
inputs=[textbox],
|
661 |
+
)
|
662 |
+
|
663 |
+
with gr.Column(scale=8):
|
664 |
+
chatbot = gr.Chatbot(
|
665 |
+
elem_id="chatbot",
|
666 |
+
label="InternVL2",
|
667 |
+
height=580,
|
668 |
+
show_copy_button=True,
|
669 |
+
show_share_button=True,
|
670 |
+
avatar_images=[
|
671 |
+
"assets/human.png",
|
672 |
+
"assets/assistant.png",
|
673 |
+
],
|
674 |
+
bubble_full_width=False,
|
675 |
+
)
|
676 |
+
with gr.Row():
|
677 |
+
with gr.Column(scale=8):
|
678 |
+
textbox.render()
|
679 |
+
with gr.Column(scale=1, min_width=50):
|
680 |
+
submit_btn = gr.Button(value="Send", variant="primary")
|
681 |
+
with gr.Row(elem_id="buttons") as button_row:
|
682 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
683 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
684 |
+
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
|
685 |
+
# stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
|
686 |
+
regenerate_btn = gr.Button(
|
687 |
+
value="🔄 Regenerate", interactive=False
|
688 |
+
)
|
689 |
+
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
|
690 |
+
|
691 |
+
if not embed_mode:
|
692 |
+
gr.Markdown(tos_markdown)
|
693 |
+
gr.Markdown(learn_more_markdown)
|
694 |
+
url_params = gr.JSON(visible=False)
|
695 |
+
|
696 |
+
# Register listeners
|
697 |
+
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
|
698 |
+
upvote_btn.click(
|
699 |
+
upvote_last_response,
|
700 |
+
[state, model_selector],
|
701 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
702 |
+
)
|
703 |
+
downvote_btn.click(
|
704 |
+
downvote_last_response,
|
705 |
+
[state, model_selector],
|
706 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
707 |
+
)
|
708 |
+
chatbot.like(
|
709 |
+
vote_selected_response,
|
710 |
+
[state, model_selector],
|
711 |
+
[],
|
712 |
+
)
|
713 |
+
flag_btn.click(
|
714 |
+
flag_last_response,
|
715 |
+
[state, model_selector],
|
716 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
717 |
+
)
|
718 |
+
regenerate_btn.click(
|
719 |
+
regenerate,
|
720 |
+
[state, system_prompt],
|
721 |
+
[state, chatbot, textbox] + btn_list,
|
722 |
+
).then(
|
723 |
+
http_bot,
|
724 |
+
[
|
725 |
+
state,
|
726 |
+
model_selector,
|
727 |
+
temperature,
|
728 |
+
top_p,
|
729 |
+
repetition_penalty,
|
730 |
+
max_output_tokens,
|
731 |
+
max_input_tiles,
|
732 |
+
# bbox_threshold,
|
733 |
+
# mask_threshold,
|
734 |
+
],
|
735 |
+
[state, chatbot, textbox] + btn_list,
|
736 |
+
)
|
737 |
+
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
|
738 |
+
|
739 |
+
textbox.submit(
|
740 |
+
add_text,
|
741 |
+
[state, textbox, system_prompt],
|
742 |
+
[state, chatbot, textbox] + btn_list,
|
743 |
+
).then(
|
744 |
+
http_bot,
|
745 |
+
[
|
746 |
+
state,
|
747 |
+
model_selector,
|
748 |
+
temperature,
|
749 |
+
top_p,
|
750 |
+
repetition_penalty,
|
751 |
+
max_output_tokens,
|
752 |
+
max_input_tiles,
|
753 |
+
# bbox_threshold,
|
754 |
+
# mask_threshold,
|
755 |
+
],
|
756 |
+
[state, chatbot, textbox] + btn_list,
|
757 |
+
)
|
758 |
+
submit_btn.click(
|
759 |
+
add_text,
|
760 |
+
[state, textbox, system_prompt],
|
761 |
+
[state, chatbot, textbox] + btn_list,
|
762 |
+
).then(
|
763 |
+
http_bot,
|
764 |
+
[
|
765 |
+
state,
|
766 |
+
model_selector,
|
767 |
+
temperature,
|
768 |
+
top_p,
|
769 |
+
repetition_penalty,
|
770 |
+
max_output_tokens,
|
771 |
+
max_input_tiles,
|
772 |
+
# bbox_threshold,
|
773 |
+
# mask_threshold,
|
774 |
+
],
|
775 |
+
[state, chatbot, textbox] + btn_list,
|
776 |
+
)
|
777 |
+
|
778 |
+
if args.model_list_mode == "once":
|
779 |
+
demo.load(
|
780 |
+
load_demo,
|
781 |
+
[url_params],
|
782 |
+
[state, model_selector],
|
783 |
+
js=js,
|
784 |
+
)
|
785 |
+
elif args.model_list_mode == "reload":
|
786 |
+
demo.load(
|
787 |
+
load_demo_refresh_model_list,
|
788 |
+
None,
|
789 |
+
[state, model_selector],
|
790 |
+
js=js,
|
791 |
+
)
|
792 |
+
else:
|
793 |
+
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
794 |
+
|
795 |
+
return demo
|
796 |
+
|
797 |
+
|
798 |
+
if __name__ == "__main__":
|
799 |
+
parser = argparse.ArgumentParser()
|
800 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
801 |
+
parser.add_argument("--port", type=int, default=11000)
|
802 |
+
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
|
803 |
+
parser.add_argument("--concurrency-count", type=int, default=10)
|
804 |
+
parser.add_argument(
|
805 |
+
"--model-list-mode", type=str, default="once", choices=["once", "reload"]
|
806 |
+
)
|
807 |
+
parser.add_argument("--sd-worker-url", type=str, default=None)
|
808 |
+
parser.add_argument("--share", action="store_true")
|
809 |
+
parser.add_argument("--moderate", action="store_true")
|
810 |
+
parser.add_argument("--embed", action="store_true")
|
811 |
+
args = parser.parse_args()
|
812 |
+
logger.info(f"args: {args}")
|
813 |
+
|
814 |
+
models = get_model_list()
|
815 |
+
|
816 |
+
sd_worker_url = args.sd_worker_url
|
817 |
+
logger.info(args)
|
818 |
+
demo = build_demo(args.embed)
|
819 |
+
demo.queue(api_open=False).launch(
|
820 |
+
server_name=args.host,
|
821 |
+
server_port=args.port,
|
822 |
+
share=args.share,
|
823 |
+
max_threads=args.concurrency_count,
|
824 |
+
)
|
library.py
DELETED
@@ -1,95 +0,0 @@
|
|
1 |
-
# --------------------------------------------------------
|
2 |
-
# InternVL
|
3 |
-
# Copyright (c) 2024 OpenGVLab
|
4 |
-
# Licensed under The MIT License [see LICENSE for details]
|
5 |
-
# Modified from https://github.com/hreikin/streamlit-uploads-library/blob/main/streamlit_uploads_library/library.py
|
6 |
-
# --------------------------------------------------------
|
7 |
-
|
8 |
-
import logging
|
9 |
-
from math import ceil
|
10 |
-
|
11 |
-
import streamlit as st
|
12 |
-
|
13 |
-
logger = logging.getLogger(__name__)
|
14 |
-
|
15 |
-
|
16 |
-
class Library():
|
17 |
-
"""Create a simple library out of streamlit widgets.
|
18 |
-
|
19 |
-
Using the library is simple, import `streamlit_uploads_library` and then instantiate the class with the
|
20 |
-
required `directory` variable. Other options can be configured by passing in different variables
|
21 |
-
when instantiating the class.
|
22 |
-
|
23 |
-
Example Usage:
|
24 |
-
python
|
25 |
-
import streamlit as st
|
26 |
-
from library import Library
|
27 |
-
|
28 |
-
st.set_page_config(page_title="Streamlit Uploads Library", layout="wide")
|
29 |
-
default_library = Library(images=pil_images)
|
30 |
-
"""
|
31 |
-
|
32 |
-
def __init__(self, images, image_alignment='end', number_of_columns=5):
|
33 |
-
self.images = images
|
34 |
-
self.image_alignment = image_alignment
|
35 |
-
self.number_of_columns = number_of_columns
|
36 |
-
self.root_container = self.create(images=self.images,
|
37 |
-
image_alignment=self.image_alignment,
|
38 |
-
number_of_columns=self.number_of_columns)
|
39 |
-
|
40 |
-
def create(_self, images, image_alignment, number_of_columns):
|
41 |
-
"""Creates a simple library or gallery with columns.
|
42 |
-
|
43 |
-
Creates a library or gallery using columns out of streamlit widgets.
|
44 |
-
"""
|
45 |
-
root_container = st.container()
|
46 |
-
with root_container:
|
47 |
-
# To be able to display the images, details and buttons all in one row and aligned
|
48 |
-
# correctly so that images of different sizes don't affect the alignment of the details
|
49 |
-
# and buttons we need do some minor maths and keep track of multiple index values.
|
50 |
-
# First we instantiate some defaults.
|
51 |
-
col_idx = 0
|
52 |
-
filename_idx = 0
|
53 |
-
max_idx = number_of_columns - 1
|
54 |
-
# Get the file list and filename list, work out the total number of files from the
|
55 |
-
# length of the file list.
|
56 |
-
library_files = images
|
57 |
-
num_of_files = len(library_files)
|
58 |
-
# Work out the number of rows required by dividing the number of files by the number of
|
59 |
-
# columns and rounding up using `math.ceil`.
|
60 |
-
num_of_rows_req = ceil(num_of_files / number_of_columns)
|
61 |
-
# Create the required number of rows (st.container).
|
62 |
-
library_rows = list()
|
63 |
-
library_rows_idx = 0
|
64 |
-
for i in range(num_of_rows_req):
|
65 |
-
library_rows.append(st.container())
|
66 |
-
# For each library row we need to create separate rows (st.container) for images,
|
67 |
-
# and rows (st.expander) for details and buttons to keep them in the correct columns.
|
68 |
-
for idx in range(num_of_rows_req):
|
69 |
-
with library_rows[library_rows_idx]:
|
70 |
-
imgs_columns = list(st.columns(number_of_columns))
|
71 |
-
# Since we are keeping track of the column and filename indexes we can use
|
72 |
-
# those to slice the `library_files` list at the correct points for each row
|
73 |
-
# and then increase or reset the indexes as required.
|
74 |
-
for img in library_files[filename_idx:(filename_idx + number_of_columns)]:
|
75 |
-
with imgs_columns[col_idx]:
|
76 |
-
st.image(img, use_column_width='auto')
|
77 |
-
st.write(
|
78 |
-
f"""<style>
|
79 |
-
[data-testid="stHorizontalBlock"] {{
|
80 |
-
align-items: {image_alignment};
|
81 |
-
}}
|
82 |
-
</style>
|
83 |
-
""",
|
84 |
-
unsafe_allow_html=True
|
85 |
-
)
|
86 |
-
# Keeps track of the current column, if we reach the `max_idx` we reset it
|
87 |
-
# to 0 and increase the row index. This combined with the slicing should
|
88 |
-
# ensure all images, details and buttons are in the correct columns.
|
89 |
-
if col_idx < max_idx:
|
90 |
-
col_idx += 1
|
91 |
-
else:
|
92 |
-
col_idx = 0
|
93 |
-
library_rows_idx += 1
|
94 |
-
filename_idx += 1
|
95 |
-
return root_container
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mm_utils.py
DELETED
@@ -1,102 +0,0 @@
|
|
1 |
-
import base64
|
2 |
-
from io import BytesIO
|
3 |
-
|
4 |
-
import torch
|
5 |
-
from PIL import Image
|
6 |
-
from transformers import StoppingCriteria
|
7 |
-
|
8 |
-
from .constants import IMAGE_TOKEN_INDEX
|
9 |
-
|
10 |
-
|
11 |
-
def load_image_from_base64(image):
|
12 |
-
return Image.open(BytesIO(base64.b64decode(image)))
|
13 |
-
|
14 |
-
|
15 |
-
def expand2square(pil_img, background_color):
|
16 |
-
width, height = pil_img.size
|
17 |
-
if width == height:
|
18 |
-
return pil_img
|
19 |
-
elif width > height:
|
20 |
-
result = Image.new(pil_img.mode, (width, width), background_color)
|
21 |
-
result.paste(pil_img, (0, (width - height) // 2))
|
22 |
-
return result
|
23 |
-
else:
|
24 |
-
result = Image.new(pil_img.mode, (height, height), background_color)
|
25 |
-
result.paste(pil_img, ((height - width) // 2, 0))
|
26 |
-
return result
|
27 |
-
|
28 |
-
|
29 |
-
def process_images(images, image_processor, model_cfg):
|
30 |
-
image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None)
|
31 |
-
new_images = []
|
32 |
-
if image_aspect_ratio == 'pad':
|
33 |
-
for image in images:
|
34 |
-
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
|
35 |
-
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
36 |
-
new_images.append(image)
|
37 |
-
else:
|
38 |
-
return image_processor(images, return_tensors='pt')['pixel_values']
|
39 |
-
if all(x.shape == new_images[0].shape for x in new_images):
|
40 |
-
new_images = torch.stack(new_images, dim=0)
|
41 |
-
return new_images
|
42 |
-
|
43 |
-
|
44 |
-
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
|
45 |
-
num_image_tokens=None, return_tensors=None):
|
46 |
-
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
47 |
-
|
48 |
-
def insert_separator(X, sep):
|
49 |
-
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
50 |
-
|
51 |
-
input_ids = []
|
52 |
-
offset = 0
|
53 |
-
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
54 |
-
offset = 1
|
55 |
-
input_ids.append(prompt_chunks[0][0])
|
56 |
-
|
57 |
-
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + num_image_tokens)):
|
58 |
-
input_ids.extend(x[offset:])
|
59 |
-
|
60 |
-
if return_tensors is not None:
|
61 |
-
if return_tensors == 'pt':
|
62 |
-
return torch.tensor(input_ids, dtype=torch.long)
|
63 |
-
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
64 |
-
return input_ids
|
65 |
-
|
66 |
-
|
67 |
-
def get_model_name_from_path(model_path):
|
68 |
-
model_path = model_path.strip('/')
|
69 |
-
model_paths = model_path.split('/')
|
70 |
-
if model_paths[-1].startswith('checkpoint-'):
|
71 |
-
return model_paths[-2] + '_' + model_paths[-1]
|
72 |
-
else:
|
73 |
-
return model_paths[-1]
|
74 |
-
|
75 |
-
|
76 |
-
class KeywordsStoppingCriteria(StoppingCriteria):
|
77 |
-
def __init__(self, keywords, tokenizer, input_ids):
|
78 |
-
self.keywords = keywords
|
79 |
-
self.keyword_ids = []
|
80 |
-
self.max_keyword_len = 0
|
81 |
-
for keyword in keywords:
|
82 |
-
cur_keyword_ids = tokenizer(keyword).input_ids
|
83 |
-
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
84 |
-
cur_keyword_ids = cur_keyword_ids[1:]
|
85 |
-
if len(cur_keyword_ids) > self.max_keyword_len:
|
86 |
-
self.max_keyword_len = len(cur_keyword_ids)
|
87 |
-
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
88 |
-
self.tokenizer = tokenizer
|
89 |
-
self.start_len = input_ids.shape[1]
|
90 |
-
|
91 |
-
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
92 |
-
assert output_ids.shape[0] == 1, 'Only support batch size 1 (yet)' # TODO
|
93 |
-
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
94 |
-
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
95 |
-
for keyword_id in self.keyword_ids:
|
96 |
-
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
|
97 |
-
return True
|
98 |
-
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
99 |
-
for keyword in self.keywords:
|
100 |
-
if keyword in outputs:
|
101 |
-
return True
|
102 |
-
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_worker.py
CHANGED
@@ -9,14 +9,15 @@ A model worker executes the model.
|
|
9 |
"""
|
10 |
import argparse
|
11 |
import asyncio
|
12 |
-
|
13 |
import json
|
14 |
-
import
|
15 |
import threading
|
16 |
import time
|
17 |
import uuid
|
|
|
18 |
from functools import partial
|
19 |
-
|
20 |
from threading import Thread
|
21 |
|
22 |
import requests
|
@@ -28,33 +29,36 @@ from fastapi import BackgroundTasks, FastAPI, Request
|
|
28 |
from fastapi.responses import StreamingResponse
|
29 |
from PIL import Image
|
30 |
from torchvision.transforms.functional import InterpolationMode
|
31 |
-
from transformers import
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
worker_id = str(uuid.uuid4())[:6]
|
36 |
-
logger = build_logger(
|
37 |
global_counter = 0
|
38 |
model_semaphore = None
|
39 |
|
40 |
|
41 |
-
def load_image_from_base64(image):
|
42 |
-
return Image.open(BytesIO(base64.b64decode(image)))
|
43 |
-
|
44 |
-
|
45 |
def build_transform(input_size):
|
46 |
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
47 |
-
transform = T.Compose(
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
53 |
return transform
|
54 |
|
55 |
|
56 |
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
57 |
-
best_ratio_diff = float(
|
58 |
best_ratio = (1, 1)
|
59 |
area = width * height
|
60 |
for ratio in target_ratios:
|
@@ -69,19 +73,26 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_
|
|
69 |
return best_ratio
|
70 |
|
71 |
|
72 |
-
def dynamic_preprocess(
|
|
|
|
|
73 |
orig_width, orig_height = image.size
|
74 |
aspect_ratio = orig_width / orig_height
|
75 |
|
76 |
# calculate the existing image aspect ratio
|
77 |
target_ratios = set(
|
78 |
-
(i, j)
|
79 |
-
|
|
|
|
|
|
|
|
|
80 |
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
81 |
|
82 |
# find the closest aspect ratio to the target
|
83 |
target_aspect_ratio = find_closest_aspect_ratio(
|
84 |
-
aspect_ratio, target_ratios, orig_width, orig_height, image_size
|
|
|
85 |
|
86 |
# calculate the target width and height
|
87 |
target_width = image_size * target_aspect_ratio[0]
|
@@ -96,7 +107,7 @@ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnai
|
|
96 |
(i % (target_width // image_size)) * image_size,
|
97 |
(i // (target_width // image_size)) * image_size,
|
98 |
((i % (target_width // image_size)) + 1) * image_size,
|
99 |
-
((i // (target_width // image_size)) + 1) * image_size
|
100 |
)
|
101 |
# split the image
|
102 |
split_img = resized_img.crop(box)
|
@@ -114,78 +125,163 @@ def heart_beat_worker(controller):
|
|
114 |
controller.send_heart_beat()
|
115 |
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
class ModelWorker:
|
118 |
-
def __init__(
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
self.controller_addr = controller_addr
|
121 |
self.worker_addr = worker_addr
|
122 |
self.worker_id = worker_id
|
123 |
-
if model_path.endswith(
|
124 |
model_path = model_path[:-1]
|
125 |
if model_name is None:
|
126 |
-
model_paths = model_path.split(
|
127 |
-
if model_paths[-1].startswith(
|
128 |
-
self.model_name = model_paths[-2] +
|
129 |
else:
|
130 |
self.model_name = model_paths[-1]
|
131 |
else:
|
132 |
self.model_name = model_name
|
133 |
|
134 |
-
logger.info(f
|
135 |
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
model_path,
|
142 |
load_in_8bit=load_8bit,
|
143 |
-
torch_dtype=torch.
|
144 |
-
device_map=
|
145 |
-
trust_remote_code=True
|
|
|
146 |
else:
|
147 |
-
self.model =
|
148 |
model_path,
|
149 |
load_in_8bit=load_8bit,
|
150 |
-
torch_dtype=torch.
|
151 |
-
trust_remote_code=True
|
152 |
-
|
|
|
153 |
self.model = self.model.cuda()
|
|
|
|
|
|
|
154 |
self.image_size = self.model.config.force_image_size
|
155 |
self.context_len = context_len
|
156 |
self.register_to_controller()
|
157 |
self.heart_beat_thread = threading.Thread(
|
158 |
-
target=heart_beat_worker, args=(self,)
|
|
|
159 |
self.heart_beat_thread.start()
|
160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
def register_to_controller(self):
|
162 |
-
logger.info(
|
163 |
|
164 |
-
url = self.controller_addr +
|
165 |
data = {
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
}
|
170 |
r = requests.post(url, json=data)
|
171 |
assert r.status_code == 200
|
172 |
|
173 |
def send_heart_beat(self):
|
174 |
-
logger.info(
|
175 |
-
|
176 |
-
|
|
|
|
|
177 |
|
178 |
-
url = self.controller_addr +
|
179 |
|
180 |
while True:
|
181 |
try:
|
182 |
-
ret = requests.post(
|
183 |
-
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
186 |
break
|
187 |
except requests.exceptions.RequestException as e:
|
188 |
-
logger.error(f
|
189 |
time.sleep(5)
|
190 |
|
191 |
if not exist:
|
@@ -195,80 +291,115 @@ class ModelWorker:
|
|
195 |
if model_semaphore is None:
|
196 |
return 0
|
197 |
else:
|
198 |
-
return
|
199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
|
201 |
def get_status(self):
|
202 |
return {
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
}
|
207 |
|
|
|
208 |
@torch.inference_mode()
|
209 |
def generate_stream(self, params):
|
210 |
-
system_message = params[
|
211 |
-
send_messages = params[
|
212 |
-
max_input_tiles = params[
|
213 |
-
temperature = params[
|
214 |
-
top_p = params[
|
215 |
-
max_new_tokens = params[
|
216 |
-
repetition_penalty = params[
|
217 |
do_sample = True if temperature > 0.0 else False
|
218 |
|
219 |
-
global_image_cnt =
|
220 |
history, pil_images, max_input_tile_list = [], [], []
|
221 |
for message in send_messages:
|
222 |
-
if message[
|
223 |
-
prefix =
|
224 |
-
if
|
225 |
max_input_tile_temp = []
|
226 |
-
for image_str in message[
|
227 |
pil_images.append(load_image_from_base64(image_str))
|
228 |
-
prefix += f
|
229 |
global_image_cnt += 1
|
230 |
-
max_input_tile_temp.append(
|
|
|
|
|
231 |
if len(max_input_tile_temp) > 0:
|
232 |
max_input_tile_list.append(max_input_tile_temp)
|
233 |
-
content = prefix + message[
|
234 |
-
history.append(
|
|
|
|
|
|
|
|
|
235 |
else:
|
236 |
-
history[-1].append(message[
|
237 |
question, history = history[-1][0], history[:-1]
|
238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
# Create a new list to store processed sublists
|
240 |
flattened_list = []
|
241 |
# Iterate through all but the last sublist in max_input_tile_list and process them
|
242 |
for sublist in max_input_tile_list[:-1]:
|
243 |
-
processed_sublist = [1] * len(
|
244 |
-
|
|
|
|
|
|
|
|
|
245 |
# If max_input_tile_list is not empty, add the last sublist to the new list
|
246 |
if max_input_tile_list:
|
247 |
flattened_list.extend(max_input_tile_list[-1])
|
248 |
max_input_tile_list = flattened_list
|
249 |
-
assert len(max_input_tile_list) == len(
|
250 |
-
|
|
|
251 |
|
252 |
old_system_message = self.model.system_message
|
253 |
self.model.system_message = system_message
|
254 |
image_tiles = []
|
255 |
transform = build_transform(input_size=self.image_size)
|
256 |
if len(pil_images) > 0:
|
257 |
-
for current_max_input_tiles, pil_image in zip(
|
|
|
|
|
258 |
if self.model.config.dynamic_image_size:
|
259 |
tiles = dynamic_preprocess(
|
260 |
-
pil_image,
|
261 |
-
|
|
|
|
|
|
|
262 |
else:
|
263 |
tiles = [pil_image]
|
264 |
image_tiles += tiles
|
265 |
pixel_values = [transform(item) for item in image_tiles]
|
266 |
-
pixel_values = torch.stack(pixel_values).to(
|
267 |
-
|
|
|
|
|
268 |
else:
|
269 |
pixel_values = None
|
270 |
|
271 |
-
streamer = TextIteratorStreamer(
|
|
|
|
|
272 |
generation_config = dict(
|
273 |
num_beams=1,
|
274 |
max_new_tokens=max_new_tokens,
|
@@ -279,53 +410,61 @@ class ModelWorker:
|
|
279 |
top_p=top_p,
|
280 |
streamer=streamer,
|
281 |
)
|
282 |
-
logger.info(
|
283 |
-
|
284 |
-
|
285 |
-
|
|
|
286 |
tokenizer=self.tokenizer,
|
287 |
pixel_values=pixel_values,
|
288 |
question=question,
|
289 |
history=history,
|
290 |
return_history=False,
|
291 |
generation_config=generation_config,
|
292 |
-
)
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
|
|
|
|
|
|
|
|
303 |
|
304 |
def generate_stream_gate(self, params):
|
305 |
try:
|
306 |
for x in self.generate_stream(params):
|
307 |
yield x
|
308 |
except ValueError as e:
|
309 |
-
print(
|
|
|
310 |
ret = {
|
311 |
-
|
312 |
-
|
313 |
}
|
314 |
-
yield json.dumps(ret).encode() + b
|
315 |
except torch.cuda.CudaError as e:
|
316 |
-
|
|
|
317 |
ret = {
|
318 |
-
|
319 |
-
|
320 |
}
|
321 |
-
yield json.dumps(ret).encode() + b
|
322 |
except Exception as e:
|
323 |
-
|
|
|
324 |
ret = {
|
325 |
-
|
326 |
-
|
327 |
}
|
328 |
-
yield json.dumps(ret).encode() + b
|
329 |
|
330 |
|
331 |
app = FastAPI()
|
@@ -337,7 +476,7 @@ def release_model_semaphore(fn=None):
|
|
337 |
fn()
|
338 |
|
339 |
|
340 |
-
@app.post(
|
341 |
async def generate_stream(request: Request):
|
342 |
global model_semaphore, global_counter
|
343 |
global_counter += 1
|
@@ -349,35 +488,39 @@ async def generate_stream(request: Request):
|
|
349 |
worker.send_heart_beat()
|
350 |
generator = worker.generate_stream_gate(params)
|
351 |
background_tasks = BackgroundTasks()
|
352 |
-
background_tasks.add_task(
|
|
|
|
|
353 |
return StreamingResponse(generator, background=background_tasks)
|
354 |
|
355 |
|
356 |
-
@app.post(
|
357 |
async def get_status(request: Request):
|
358 |
return worker.get_status()
|
359 |
|
360 |
|
361 |
-
if __name__ ==
|
362 |
parser = argparse.ArgumentParser()
|
363 |
-
parser.add_argument(
|
364 |
-
parser.add_argument(
|
365 |
-
parser.add_argument(
|
366 |
-
parser.add_argument(
|
367 |
-
parser.add_argument(
|
368 |
-
parser.add_argument(
|
369 |
-
parser.add_argument(
|
370 |
-
parser.add_argument(
|
371 |
-
parser.add_argument(
|
372 |
-
parser.add_argument(
|
373 |
args = parser.parse_args()
|
374 |
-
logger.info(f
|
375 |
-
|
376 |
-
worker = ModelWorker(
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
|
|
|
|
|
9 |
"""
|
10 |
import argparse
|
11 |
import asyncio
|
12 |
+
|
13 |
import json
|
14 |
+
import math
|
15 |
import threading
|
16 |
import time
|
17 |
import uuid
|
18 |
+
import traceback
|
19 |
from functools import partial
|
20 |
+
|
21 |
from threading import Thread
|
22 |
|
23 |
import requests
|
|
|
29 |
from fastapi.responses import StreamingResponse
|
30 |
from PIL import Image
|
31 |
from torchvision.transforms.functional import InterpolationMode
|
32 |
+
from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
|
33 |
+
from utils import (
|
34 |
+
build_logger,
|
35 |
+
pretty_print_semaphore,
|
36 |
+
server_error_msg,
|
37 |
+
load_image_from_base64,
|
38 |
+
)
|
39 |
+
import spaces
|
40 |
|
41 |
worker_id = str(uuid.uuid4())[:6]
|
42 |
+
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
|
43 |
global_counter = 0
|
44 |
model_semaphore = None
|
45 |
|
46 |
|
|
|
|
|
|
|
|
|
47 |
def build_transform(input_size):
|
48 |
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
49 |
+
transform = T.Compose(
|
50 |
+
[
|
51 |
+
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
|
52 |
+
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
|
53 |
+
T.ToTensor(),
|
54 |
+
T.Normalize(mean=MEAN, std=STD),
|
55 |
+
]
|
56 |
+
)
|
57 |
return transform
|
58 |
|
59 |
|
60 |
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
61 |
+
best_ratio_diff = float("inf")
|
62 |
best_ratio = (1, 1)
|
63 |
area = width * height
|
64 |
for ratio in target_ratios:
|
|
|
73 |
return best_ratio
|
74 |
|
75 |
|
76 |
+
def dynamic_preprocess(
|
77 |
+
image, min_num=1, max_num=6, image_size=448, use_thumbnail=False
|
78 |
+
):
|
79 |
orig_width, orig_height = image.size
|
80 |
aspect_ratio = orig_width / orig_height
|
81 |
|
82 |
# calculate the existing image aspect ratio
|
83 |
target_ratios = set(
|
84 |
+
(i, j)
|
85 |
+
for n in range(min_num, max_num + 1)
|
86 |
+
for i in range(1, n + 1)
|
87 |
+
for j in range(1, n + 1)
|
88 |
+
if i * j <= max_num and i * j >= min_num
|
89 |
+
)
|
90 |
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
91 |
|
92 |
# find the closest aspect ratio to the target
|
93 |
target_aspect_ratio = find_closest_aspect_ratio(
|
94 |
+
aspect_ratio, target_ratios, orig_width, orig_height, image_size
|
95 |
+
)
|
96 |
|
97 |
# calculate the target width and height
|
98 |
target_width = image_size * target_aspect_ratio[0]
|
|
|
107 |
(i % (target_width // image_size)) * image_size,
|
108 |
(i // (target_width // image_size)) * image_size,
|
109 |
((i % (target_width // image_size)) + 1) * image_size,
|
110 |
+
((i // (target_width // image_size)) + 1) * image_size,
|
111 |
)
|
112 |
# split the image
|
113 |
split_img = resized_img.crop(box)
|
|
|
125 |
controller.send_heart_beat()
|
126 |
|
127 |
|
128 |
+
def split_model(model_name):
|
129 |
+
device_map = {}
|
130 |
+
world_size = torch.cuda.device_count()
|
131 |
+
num_layers = {
|
132 |
+
"InternVL2-8B": 32,
|
133 |
+
"InternVL2-26B": 48,
|
134 |
+
"InternVL2-40B": 60,
|
135 |
+
"InternVL2-Llama3-76B": 80,
|
136 |
+
"InternVL2-78B": 80,
|
137 |
+
"InternVL2-Pro": 80,
|
138 |
+
}[model_name]
|
139 |
+
# Since the first GPU will be used for ViT, treat it as half a GPU.
|
140 |
+
num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
|
141 |
+
num_layers_per_gpu = [num_layers_per_gpu] * world_size
|
142 |
+
num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
|
143 |
+
layer_cnt = 0
|
144 |
+
for i, num_layer in enumerate(num_layers_per_gpu):
|
145 |
+
for j in range(num_layer):
|
146 |
+
device_map[f"language_model.model.layers.{layer_cnt}"] = i
|
147 |
+
layer_cnt += 1
|
148 |
+
device_map["vision_model"] = 0
|
149 |
+
device_map["mlp1"] = 0
|
150 |
+
device_map["language_model.model.tok_embeddings"] = 0
|
151 |
+
device_map["language_model.model.embed_tokens"] = 0
|
152 |
+
device_map["language_model.output"] = 0
|
153 |
+
device_map["language_model.model.norm"] = 0
|
154 |
+
device_map["language_model.lm_head"] = 0
|
155 |
+
device_map[f"language_model.model.layers.{num_layers - 1}"] = 0
|
156 |
+
|
157 |
+
return device_map
|
158 |
+
|
159 |
+
|
160 |
class ModelWorker:
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
controller_addr,
|
164 |
+
worker_addr,
|
165 |
+
worker_id,
|
166 |
+
model_path,
|
167 |
+
model_name,
|
168 |
+
load_8bit,
|
169 |
+
device,
|
170 |
+
context_len=8192,
|
171 |
+
):
|
172 |
self.controller_addr = controller_addr
|
173 |
self.worker_addr = worker_addr
|
174 |
self.worker_id = worker_id
|
175 |
+
if model_path.endswith("/"):
|
176 |
model_path = model_path[:-1]
|
177 |
if model_name is None:
|
178 |
+
model_paths = model_path.split("/")
|
179 |
+
if model_paths[-1].startswith("checkpoint-"):
|
180 |
+
self.model_name = model_paths[-2] + "_" + model_paths[-1]
|
181 |
else:
|
182 |
self.model_name = model_paths[-1]
|
183 |
else:
|
184 |
self.model_name = model_name
|
185 |
|
186 |
+
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
|
187 |
|
188 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
189 |
+
model_path, trust_remote_code=True, use_fast=False
|
190 |
+
)
|
191 |
+
tokens_to_keep = ["<box>", "</box>", "<ref>", "</ref>"]
|
192 |
+
tokenizer.additional_special_tokens = [
|
193 |
+
item
|
194 |
+
for item in tokenizer.additional_special_tokens
|
195 |
+
if item not in tokens_to_keep
|
196 |
+
]
|
197 |
+
self.tokenizer = tokenizer
|
198 |
+
|
199 |
+
if device == "auto":
|
200 |
+
device_map = split_model(self.model_name)
|
201 |
+
self.model = AutoModel.from_pretrained(
|
202 |
model_path,
|
203 |
load_in_8bit=load_8bit,
|
204 |
+
torch_dtype=torch.bfloat16,
|
205 |
+
device_map=device_map,
|
206 |
+
trust_remote_code=True,
|
207 |
+
).eval()
|
208 |
else:
|
209 |
+
self.model = AutoModel.from_pretrained(
|
210 |
model_path,
|
211 |
load_in_8bit=load_8bit,
|
212 |
+
torch_dtype=torch.bfloat16,
|
213 |
+
trust_remote_code=True,
|
214 |
+
).eval()
|
215 |
+
if not load_8bit and not device == "auto":
|
216 |
self.model = self.model.cuda()
|
217 |
+
self.load_8bit = load_8bit
|
218 |
+
self.device = device
|
219 |
+
self.model_path = model_path
|
220 |
self.image_size = self.model.config.force_image_size
|
221 |
self.context_len = context_len
|
222 |
self.register_to_controller()
|
223 |
self.heart_beat_thread = threading.Thread(
|
224 |
+
target=heart_beat_worker, args=(self,)
|
225 |
+
)
|
226 |
self.heart_beat_thread.start()
|
227 |
|
228 |
+
def reload_model(self):
|
229 |
+
del self.model
|
230 |
+
torch.cuda.empty_cache()
|
231 |
+
if self.device == "auto":
|
232 |
+
device_map = split_model(self.model_name)
|
233 |
+
self.model = AutoModel.from_pretrained(
|
234 |
+
self.model_path,
|
235 |
+
load_in_8bit=self.load_8bit,
|
236 |
+
torch_dtype=torch.bfloat16,
|
237 |
+
device_map=device_map,
|
238 |
+
trust_remote_code=True,
|
239 |
+
).eval()
|
240 |
+
else:
|
241 |
+
self.model = AutoModel.from_pretrained(
|
242 |
+
self.model_path,
|
243 |
+
load_in_8bit=self.load_8bit,
|
244 |
+
torch_dtype=torch.bfloat16,
|
245 |
+
trust_remote_code=True,
|
246 |
+
).eval()
|
247 |
+
if not self.load_8bit and not self.device == "auto":
|
248 |
+
self.model = self.model.cuda()
|
249 |
+
|
250 |
def register_to_controller(self):
|
251 |
+
logger.info("Register to controller")
|
252 |
|
253 |
+
url = self.controller_addr + "/register_worker"
|
254 |
data = {
|
255 |
+
"worker_name": self.worker_addr,
|
256 |
+
"check_heart_beat": True,
|
257 |
+
"worker_status": self.get_status(),
|
258 |
}
|
259 |
r = requests.post(url, json=data)
|
260 |
assert r.status_code == 200
|
261 |
|
262 |
def send_heart_beat(self):
|
263 |
+
logger.info(
|
264 |
+
f"Send heart beat. Models: {[self.model_name]}. "
|
265 |
+
f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
|
266 |
+
f"global_counter: {global_counter}"
|
267 |
+
)
|
268 |
|
269 |
+
url = self.controller_addr + "/receive_heart_beat"
|
270 |
|
271 |
while True:
|
272 |
try:
|
273 |
+
ret = requests.post(
|
274 |
+
url,
|
275 |
+
json={
|
276 |
+
"worker_name": self.worker_addr,
|
277 |
+
"queue_length": self.get_queue_length(),
|
278 |
+
},
|
279 |
+
timeout=5,
|
280 |
+
)
|
281 |
+
exist = ret.json()["exist"]
|
282 |
break
|
283 |
except requests.exceptions.RequestException as e:
|
284 |
+
logger.error(f"heart beat error: {e}")
|
285 |
time.sleep(5)
|
286 |
|
287 |
if not exist:
|
|
|
291 |
if model_semaphore is None:
|
292 |
return 0
|
293 |
else:
|
294 |
+
return (
|
295 |
+
args.limit_model_concurrency
|
296 |
+
- model_semaphore._value
|
297 |
+
+ (
|
298 |
+
len(model_semaphore._waiters)
|
299 |
+
if model_semaphore._waiters is not None
|
300 |
+
else 0
|
301 |
+
)
|
302 |
+
)
|
303 |
|
304 |
def get_status(self):
|
305 |
return {
|
306 |
+
"model_names": [self.model_name],
|
307 |
+
"speed": 1,
|
308 |
+
"queue_length": self.get_queue_length(),
|
309 |
}
|
310 |
|
311 |
+
@spaces.GPU
|
312 |
@torch.inference_mode()
|
313 |
def generate_stream(self, params):
|
314 |
+
system_message = params["prompt"][0]["content"]
|
315 |
+
send_messages = params["prompt"][1:]
|
316 |
+
max_input_tiles = params["max_input_tiles"]
|
317 |
+
temperature = params["temperature"]
|
318 |
+
top_p = params["top_p"]
|
319 |
+
max_new_tokens = params["max_new_tokens"]
|
320 |
+
repetition_penalty = params["repetition_penalty"]
|
321 |
do_sample = True if temperature > 0.0 else False
|
322 |
|
323 |
+
global_image_cnt = 0
|
324 |
history, pil_images, max_input_tile_list = [], [], []
|
325 |
for message in send_messages:
|
326 |
+
if message["role"] == "user":
|
327 |
+
prefix = ""
|
328 |
+
if "image" in message:
|
329 |
max_input_tile_temp = []
|
330 |
+
for image_str in message["image"]:
|
331 |
pil_images.append(load_image_from_base64(image_str))
|
332 |
+
prefix += f"Image-{global_image_cnt + 1}: <image>\n\n"
|
333 |
global_image_cnt += 1
|
334 |
+
max_input_tile_temp.append(
|
335 |
+
max(1, max_input_tiles // len(message["image"]))
|
336 |
+
)
|
337 |
if len(max_input_tile_temp) > 0:
|
338 |
max_input_tile_list.append(max_input_tile_temp)
|
339 |
+
content = prefix + message["content"]
|
340 |
+
history.append(
|
341 |
+
[
|
342 |
+
content,
|
343 |
+
]
|
344 |
+
)
|
345 |
else:
|
346 |
+
history[-1].append(message["content"])
|
347 |
question, history = history[-1][0], history[:-1]
|
348 |
|
349 |
+
if global_image_cnt == 1:
|
350 |
+
question = question.replace("Image-1: <image>\n\n", "<image>\n")
|
351 |
+
history = [
|
352 |
+
[item[0].replace("Image-1: <image>\n\n", "<image>\n"), item[1]]
|
353 |
+
for item in history
|
354 |
+
]
|
355 |
+
|
356 |
# Create a new list to store processed sublists
|
357 |
flattened_list = []
|
358 |
# Iterate through all but the last sublist in max_input_tile_list and process them
|
359 |
for sublist in max_input_tile_list[:-1]:
|
360 |
+
processed_sublist = [1] * len(
|
361 |
+
sublist
|
362 |
+
) # Change each element in the sublist to 1
|
363 |
+
flattened_list.extend(
|
364 |
+
processed_sublist
|
365 |
+
) # Flatten the processed sublist and add to the new list
|
366 |
# If max_input_tile_list is not empty, add the last sublist to the new list
|
367 |
if max_input_tile_list:
|
368 |
flattened_list.extend(max_input_tile_list[-1])
|
369 |
max_input_tile_list = flattened_list
|
370 |
+
assert len(max_input_tile_list) == len(
|
371 |
+
pil_images
|
372 |
+
), "The number of max_input_tile_list and pil_images should be the same."
|
373 |
|
374 |
old_system_message = self.model.system_message
|
375 |
self.model.system_message = system_message
|
376 |
image_tiles = []
|
377 |
transform = build_transform(input_size=self.image_size)
|
378 |
if len(pil_images) > 0:
|
379 |
+
for current_max_input_tiles, pil_image in zip(
|
380 |
+
max_input_tile_list, pil_images
|
381 |
+
):
|
382 |
if self.model.config.dynamic_image_size:
|
383 |
tiles = dynamic_preprocess(
|
384 |
+
pil_image,
|
385 |
+
image_size=self.image_size,
|
386 |
+
max_num=current_max_input_tiles,
|
387 |
+
use_thumbnail=self.model.config.use_thumbnail,
|
388 |
+
)
|
389 |
else:
|
390 |
tiles = [pil_image]
|
391 |
image_tiles += tiles
|
392 |
pixel_values = [transform(item) for item in image_tiles]
|
393 |
+
pixel_values = torch.stack(pixel_values).to(
|
394 |
+
self.model.device, dtype=torch.bfloat16
|
395 |
+
)
|
396 |
+
logger.info(f"Split images to {pixel_values.shape}")
|
397 |
else:
|
398 |
pixel_values = None
|
399 |
|
400 |
+
streamer = TextIteratorStreamer(
|
401 |
+
self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
|
402 |
+
)
|
403 |
generation_config = dict(
|
404 |
num_beams=1,
|
405 |
max_new_tokens=max_new_tokens,
|
|
|
410 |
top_p=top_p,
|
411 |
streamer=streamer,
|
412 |
)
|
413 |
+
logger.info(f"Generation config: {generation_config}")
|
414 |
+
|
415 |
+
thread = Thread(
|
416 |
+
target=self.model.chat,
|
417 |
+
kwargs=dict(
|
418 |
tokenizer=self.tokenizer,
|
419 |
pixel_values=pixel_values,
|
420 |
question=question,
|
421 |
history=history,
|
422 |
return_history=False,
|
423 |
generation_config=generation_config,
|
424 |
+
),
|
425 |
+
)
|
426 |
+
thread.start()
|
427 |
+
|
428 |
+
generated_text = ""
|
429 |
+
for new_text in streamer:
|
430 |
+
generated_text += new_text
|
431 |
+
if generated_text.endswith(self.model.conv_template.sep):
|
432 |
+
generated_text = generated_text[: -len(self.model.conv_template.sep)]
|
433 |
+
yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
|
434 |
+
logger.info(
|
435 |
+
f"max_input_tile_list: {max_input_tile_list}, history: {history}, "
|
436 |
+
f"question: {question}, answer: {generated_text}"
|
437 |
+
)
|
438 |
+
self.model.system_message = old_system_message
|
439 |
|
440 |
def generate_stream_gate(self, params):
|
441 |
try:
|
442 |
for x in self.generate_stream(params):
|
443 |
yield x
|
444 |
except ValueError as e:
|
445 |
+
print("Caught ValueError:", e)
|
446 |
+
traceback.print_exc()
|
447 |
ret = {
|
448 |
+
"text": server_error_msg,
|
449 |
+
"error_code": 1,
|
450 |
}
|
451 |
+
yield json.dumps(ret).encode() + b"\0"
|
452 |
except torch.cuda.CudaError as e:
|
453 |
+
traceback.print_exc()
|
454 |
+
print("Caught torch.cuda.CudaError:", e)
|
455 |
ret = {
|
456 |
+
"text": server_error_msg,
|
457 |
+
"error_code": 1,
|
458 |
}
|
459 |
+
yield json.dumps(ret).encode() + b"\0"
|
460 |
except Exception as e:
|
461 |
+
traceback.print_exc()
|
462 |
+
print("Caught Unknown Error", e)
|
463 |
ret = {
|
464 |
+
"text": server_error_msg,
|
465 |
+
"error_code": 1,
|
466 |
}
|
467 |
+
yield json.dumps(ret).encode() + b"\0"
|
468 |
|
469 |
|
470 |
app = FastAPI()
|
|
|
476 |
fn()
|
477 |
|
478 |
|
479 |
+
@app.post("/worker_generate_stream")
|
480 |
async def generate_stream(request: Request):
|
481 |
global model_semaphore, global_counter
|
482 |
global_counter += 1
|
|
|
488 |
worker.send_heart_beat()
|
489 |
generator = worker.generate_stream_gate(params)
|
490 |
background_tasks = BackgroundTasks()
|
491 |
+
background_tasks.add_task(
|
492 |
+
partial(release_model_semaphore, fn=worker.send_heart_beat)
|
493 |
+
)
|
494 |
return StreamingResponse(generator, background=background_tasks)
|
495 |
|
496 |
|
497 |
+
@app.post("/worker_get_status")
|
498 |
async def get_status(request: Request):
|
499 |
return worker.get_status()
|
500 |
|
501 |
|
502 |
+
if __name__ == "__main__":
|
503 |
parser = argparse.ArgumentParser()
|
504 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
505 |
+
parser.add_argument("--port", type=int, default=21002)
|
506 |
+
parser.add_argument("--worker-url", type=str, default="http://localhost")
|
507 |
+
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
|
508 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
509 |
+
parser.add_argument("--model-name", type=str)
|
510 |
+
parser.add_argument("--device", type=str, default="cuda")
|
511 |
+
parser.add_argument("--limit-model-concurrency", type=int, default=5)
|
512 |
+
parser.add_argument("--stream-interval", type=int, default=1)
|
513 |
+
parser.add_argument("--load-8bit", action="store_true")
|
514 |
args = parser.parse_args()
|
515 |
+
logger.info(f"args: {args}")
|
516 |
+
|
517 |
+
worker = ModelWorker(
|
518 |
+
args.controller_url,
|
519 |
+
args.worker_url + f":{args.port}",
|
520 |
+
worker_id,
|
521 |
+
args.model_path,
|
522 |
+
args.model_name,
|
523 |
+
args.load_8bit,
|
524 |
+
args.device,
|
525 |
+
)
|
526 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
requirements.txt
CHANGED
@@ -1,4 +1,14 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers==0.29.2
|
2 |
+
fastapi==0.111.1
|
3 |
+
filelock==3.15.4
|
4 |
+
fire==0.6.0
|
5 |
+
gradio==4.38.1
|
6 |
+
numpy==2.0.1
|
7 |
+
Pillow==10.4.0
|
8 |
+
pydantic==2.8.2
|
9 |
+
Requests==2.32.3
|
10 |
+
spaces==0.28.3
|
11 |
+
torch==2.0.1
|
12 |
+
torchvision==0.15.2
|
13 |
+
transformers==4.37.2
|
14 |
+
uvicorn==0.30.3
|
utils.py
CHANGED
@@ -1,13 +1,22 @@
|
|
|
|
1 |
import logging
|
2 |
import logging.handlers
|
3 |
import os
|
4 |
import sys
|
5 |
-
|
|
|
|
|
|
|
6 |
import requests
|
7 |
from constants import LOGDIR
|
|
|
8 |
|
9 |
-
server_error_msg =
|
10 |
-
|
|
|
|
|
|
|
|
|
11 |
|
12 |
handler = None
|
13 |
|
@@ -16,8 +25,8 @@ def build_logger(logger_name, logger_filename):
|
|
16 |
global handler
|
17 |
|
18 |
formatter = logging.Formatter(
|
19 |
-
fmt=
|
20 |
-
datefmt=
|
21 |
)
|
22 |
|
23 |
# Set the format of root handlers
|
@@ -26,12 +35,12 @@ def build_logger(logger_name, logger_filename):
|
|
26 |
logging.getLogger().handlers[0].setFormatter(formatter)
|
27 |
|
28 |
# Redirect stdout and stderr to loggers
|
29 |
-
stdout_logger = logging.getLogger(
|
30 |
stdout_logger.setLevel(logging.INFO)
|
31 |
sl = StreamToLogger(stdout_logger, logging.INFO)
|
32 |
sys.stdout = sl
|
33 |
|
34 |
-
stderr_logger = logging.getLogger(
|
35 |
stderr_logger.setLevel(logging.ERROR)
|
36 |
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
37 |
sys.stderr = sl
|
@@ -45,7 +54,8 @@ def build_logger(logger_name, logger_filename):
|
|
45 |
os.makedirs(LOGDIR, exist_ok=True)
|
46 |
filename = os.path.join(LOGDIR, logger_filename)
|
47 |
handler = logging.handlers.TimedRotatingFileHandler(
|
48 |
-
filename, when=
|
|
|
49 |
handler.setFormatter(formatter)
|
50 |
|
51 |
for name, item in logging.root.manager.loggerDict.items():
|
@@ -59,33 +69,34 @@ class StreamToLogger(object):
|
|
59 |
"""
|
60 |
Fake file-like stream object that redirects writes to a logger instance.
|
61 |
"""
|
|
|
62 |
def __init__(self, logger, log_level=logging.INFO):
|
63 |
self.terminal = sys.stdout
|
64 |
self.logger = logger
|
65 |
self.log_level = log_level
|
66 |
-
self.linebuf =
|
67 |
|
68 |
def __getattr__(self, attr):
|
69 |
return getattr(self.terminal, attr)
|
70 |
|
71 |
def write(self, buf):
|
72 |
temp_linebuf = self.linebuf + buf
|
73 |
-
self.linebuf =
|
74 |
for line in temp_linebuf.splitlines(True):
|
75 |
# From the io.TextIOWrapper docs:
|
76 |
# On output, if newline is None, any '\n' characters written
|
77 |
# are translated to the system default line separator.
|
78 |
# By default sys.stdout.write() expects '\n' newlines and then
|
79 |
# translates them so this is still cross platform.
|
80 |
-
if line[-1] ==
|
81 |
self.logger.log(self.log_level, line.rstrip())
|
82 |
else:
|
83 |
self.linebuf += line
|
84 |
|
85 |
def flush(self):
|
86 |
-
if self.linebuf !=
|
87 |
self.logger.log(self.log_level, self.linebuf.rstrip())
|
88 |
-
self.linebuf =
|
89 |
|
90 |
|
91 |
def disable_torch_init():
|
@@ -93,23 +104,26 @@ def disable_torch_init():
|
|
93 |
Disable the redundant torch default initialization to accelerate model creation.
|
94 |
"""
|
95 |
import torch
|
96 |
-
|
97 |
-
setattr(torch.nn.
|
|
|
98 |
|
99 |
|
100 |
def violates_moderation(text):
|
101 |
"""
|
102 |
Check whether the text violates OpenAI moderation API.
|
103 |
"""
|
104 |
-
url =
|
105 |
-
headers = {
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
110 |
try:
|
111 |
ret = requests.post(url, headers=headers, data=data, timeout=5)
|
112 |
-
flagged = ret.json()[
|
113 |
except requests.exceptions.RequestException as e:
|
114 |
flagged = False
|
115 |
except KeyError as e:
|
@@ -120,5 +134,30 @@ def violates_moderation(text):
|
|
120 |
|
121 |
def pretty_print_semaphore(semaphore):
|
122 |
if semaphore is None:
|
123 |
-
return
|
124 |
-
return f
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ast import Dict
|
2 |
import logging
|
3 |
import logging.handlers
|
4 |
import os
|
5 |
import sys
|
6 |
+
import base64
|
7 |
+
from PIL import Image
|
8 |
+
from io import BytesIO
|
9 |
+
import json
|
10 |
import requests
|
11 |
from constants import LOGDIR
|
12 |
+
import datetime
|
13 |
|
14 |
+
server_error_msg = (
|
15 |
+
"**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
16 |
+
)
|
17 |
+
moderation_msg = (
|
18 |
+
"YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
|
19 |
+
)
|
20 |
|
21 |
handler = None
|
22 |
|
|
|
25 |
global handler
|
26 |
|
27 |
formatter = logging.Formatter(
|
28 |
+
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
29 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
30 |
)
|
31 |
|
32 |
# Set the format of root handlers
|
|
|
35 |
logging.getLogger().handlers[0].setFormatter(formatter)
|
36 |
|
37 |
# Redirect stdout and stderr to loggers
|
38 |
+
stdout_logger = logging.getLogger("stdout")
|
39 |
stdout_logger.setLevel(logging.INFO)
|
40 |
sl = StreamToLogger(stdout_logger, logging.INFO)
|
41 |
sys.stdout = sl
|
42 |
|
43 |
+
stderr_logger = logging.getLogger("stderr")
|
44 |
stderr_logger.setLevel(logging.ERROR)
|
45 |
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
46 |
sys.stderr = sl
|
|
|
54 |
os.makedirs(LOGDIR, exist_ok=True)
|
55 |
filename = os.path.join(LOGDIR, logger_filename)
|
56 |
handler = logging.handlers.TimedRotatingFileHandler(
|
57 |
+
filename, when="D", utc=True
|
58 |
+
)
|
59 |
handler.setFormatter(formatter)
|
60 |
|
61 |
for name, item in logging.root.manager.loggerDict.items():
|
|
|
69 |
"""
|
70 |
Fake file-like stream object that redirects writes to a logger instance.
|
71 |
"""
|
72 |
+
|
73 |
def __init__(self, logger, log_level=logging.INFO):
|
74 |
self.terminal = sys.stdout
|
75 |
self.logger = logger
|
76 |
self.log_level = log_level
|
77 |
+
self.linebuf = ""
|
78 |
|
79 |
def __getattr__(self, attr):
|
80 |
return getattr(self.terminal, attr)
|
81 |
|
82 |
def write(self, buf):
|
83 |
temp_linebuf = self.linebuf + buf
|
84 |
+
self.linebuf = ""
|
85 |
for line in temp_linebuf.splitlines(True):
|
86 |
# From the io.TextIOWrapper docs:
|
87 |
# On output, if newline is None, any '\n' characters written
|
88 |
# are translated to the system default line separator.
|
89 |
# By default sys.stdout.write() expects '\n' newlines and then
|
90 |
# translates them so this is still cross platform.
|
91 |
+
if line[-1] == "\n":
|
92 |
self.logger.log(self.log_level, line.rstrip())
|
93 |
else:
|
94 |
self.linebuf += line
|
95 |
|
96 |
def flush(self):
|
97 |
+
if self.linebuf != "":
|
98 |
self.logger.log(self.log_level, self.linebuf.rstrip())
|
99 |
+
self.linebuf = ""
|
100 |
|
101 |
|
102 |
def disable_torch_init():
|
|
|
104 |
Disable the redundant torch default initialization to accelerate model creation.
|
105 |
"""
|
106 |
import torch
|
107 |
+
|
108 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
109 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
110 |
|
111 |
|
112 |
def violates_moderation(text):
|
113 |
"""
|
114 |
Check whether the text violates OpenAI moderation API.
|
115 |
"""
|
116 |
+
url = "https://api.openai.com/v1/moderations"
|
117 |
+
headers = {
|
118 |
+
"Content-Type": "application/json",
|
119 |
+
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
|
120 |
+
}
|
121 |
+
text = text.replace("\n", "")
|
122 |
+
data = "{" + '"input": ' + f'"{text}"' + "}"
|
123 |
+
data = data.encode("utf-8")
|
124 |
try:
|
125 |
ret = requests.post(url, headers=headers, data=data, timeout=5)
|
126 |
+
flagged = ret.json()["results"][0]["flagged"]
|
127 |
except requests.exceptions.RequestException as e:
|
128 |
flagged = False
|
129 |
except KeyError as e:
|
|
|
134 |
|
135 |
def pretty_print_semaphore(semaphore):
|
136 |
if semaphore is None:
|
137 |
+
return "None"
|
138 |
+
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
|
139 |
+
|
140 |
+
|
141 |
+
def load_image_from_base64(image):
|
142 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
143 |
+
|
144 |
+
|
145 |
+
def get_log_filename():
|
146 |
+
t = datetime.datetime.now()
|
147 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
148 |
+
return name
|
149 |
+
|
150 |
+
|
151 |
+
def data_wrapper(data):
|
152 |
+
if isinstance(data, bytes):
|
153 |
+
return data
|
154 |
+
elif isinstance(data, Image.Image):
|
155 |
+
buffered = BytesIO()
|
156 |
+
data.save(buffered, format="PNG")
|
157 |
+
return buffered.getvalue()
|
158 |
+
elif isinstance(data, str):
|
159 |
+
return data.encode()
|
160 |
+
elif isinstance(data, Dict):
|
161 |
+
return json.dumps(data).encode()
|
162 |
+
else:
|
163 |
+
raise ValueError(f"Unsupported data type: {type(data)}")
|