Spaces: Running on Zero

Commit · 3c3eadf
Parent(s): b3f697c
Refactor API and controller files; remove unused code and update imports
Browse files
- api.py +1 -48
- controller.py +1 -1
- gradio_web_server.py +0 -87
- sd_worker.py +0 -58
api.py
CHANGED
@@ -30,51 +30,4 @@ def get_selected_worker_ip(controller_url, selected_model):
 def pil_image_to_base64(image):
     buffered = BytesIO()
     image.save(buffered, format='PNG')
-    return base64.b64encode(buffered.getvalue()).decode('utf-8')
-
-
-controller_url = 'http://10.140.60.209:10075'
-model_list = get_model_list(controller_url)
-print(f'Model list: {model_list}')
-
-selected_model = 'InternVL2-1B'
-worker_addr = get_selected_worker_ip(controller_url, selected_model)
-print(f'model_name: {selected_model}, worker_addr: {worker_addr}')
-
-
-# For multi-round / multi-image conversations, organize the data in the following format:
-# send_messages = [{'role': 'system', 'content': system_message}]
-# send_messages.append({'role': 'user', 'content': 'question1 to image1', 'image': [pil_image_to_base64(image)]})
-# send_messages.append({'role': 'assistant', 'content': 'answer1'})
-# send_messages.append({'role': 'user', 'content': 'question2 to image2', 'image': [pil_image_to_base64(image)]})
-# send_messages.append({'role': 'assistant', 'content': 'answer2'})
-# send_messages.append({'role': 'user', 'content': 'question3 to image1 & 2', 'image': []})
-
-image = Image.open('image1.jpg')
-print(f'Loading image, size: {image.size}')
-system_message = """I am InternVL (Chinese name: 书生·万象), a multimodal large language model jointly developed by Shanghai AI Laboratory and multiple partner institutions. The laboratory is committed to original technological innovation and to open source and openness, sharing and co-creation, driving scientific progress and industrial development.
-Please answer the user's questions in as much detail as possible."""
-send_messages = [{'role': 'system', 'content': system_message}]
-send_messages.append({'role': 'user', 'content': 'describe this image in detail', 'image': [pil_image_to_base64(image)]})
-
-pload = {
-    'model': selected_model,
-    'prompt': send_messages,
-    'temperature': 0.8,
-    'top_p': 0.7,
-    'max_new_tokens': 2048,
-    'max_input_tiles': 12,
-    'repetition_penalty': 1.0,
-}
-headers = {'User-Agent': 'InternVL-Chat Client'}
-response = requests.post(worker_addr + '/worker_generate_stream',
-                         headers=headers, json=pload, stream=True, timeout=10)
-for chunk in response.iter_lines(decode_unicode=False, delimiter=b'\0'):
-    if chunk:
-        data = json.loads(chunk.decode())
-        if data['error_code'] == 0:
-            output = data['text']  # streaming output here
-        else:
-            output = data['text'] + f" (error_code: {data['error_code']})"
-# the complete output
-print(output)
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
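The removed block doubled as the usage example for the worker API. Since it no longer ships in api.py, here is a condensed standalone sketch of the same streaming client; it assumes api.py still exports get_selected_worker_ip and pil_image_to_base64, and it reuses the controller address, model name, and payload from the removed code. Judging by the removed comments, each streamed chunk carries the full text generated so far, so the last chunk is the complete output.

import json

import requests
from PIL import Image

from api import get_selected_worker_ip, pil_image_to_base64

# Controller address and model name taken from the removed demo code.
controller_url = 'http://10.140.60.209:10075'
selected_model = 'InternVL2-1B'
worker_addr = get_selected_worker_ip(controller_url, selected_model)

image = Image.open('image1.jpg')
send_messages = [
    {'role': 'user', 'content': 'describe this image in detail',
     'image': [pil_image_to_base64(image)]},
]
pload = {
    'model': selected_model,
    'prompt': send_messages,
    'temperature': 0.8,
    'top_p': 0.7,
    'max_new_tokens': 2048,
    'max_input_tiles': 12,
    'repetition_penalty': 1.0,
}

# The worker streams null-delimited JSON chunks.
response = requests.post(worker_addr + '/worker_generate_stream',
                         headers={'User-Agent': 'InternVL-Chat Client'},
                         json=pload, stream=True, timeout=10)
output = ''
for chunk in response.iter_lines(decode_unicode=False, delimiter=b'\0'):
    if chunk:
        data = json.loads(chunk.decode())
        output = data['text']
        if data['error_code'] != 0:
            output += f" (error_code: {data['error_code']})"
print(output)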
controller.py
CHANGED
@@ -15,7 +15,7 @@ import numpy as np
 import requests
 import uvicorn
 from fastapi import FastAPI, Request
-from
+from starlette.responses import StreamingResponse
 from utils import build_logger, server_error_msg
 
 CONTROLLER_HEART_BEAT_EXPIRATION = 30
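The one-line fix completes a truncated import that the controller needs to stream worker output back to HTTP clients. As rough orientation only, a minimal sketch of how starlette's StreamingResponse is typically wired into a FastAPI app like this one; the endpoint and generator below are hypothetical stand-ins, not the controller's actual code:

import json

from fastapi import FastAPI, Request
from starlette.responses import StreamingResponse

app = FastAPI()

@app.post('/worker_generate_stream')
async def worker_generate_stream(request: Request):
    params = await request.json()

    def stream():
        # Hypothetical generator: the real controller would forward
        # chunks from the selected worker instead of producing them.
        for text in ('Hello', 'Hello, world'):
            yield json.dumps({'text': text, 'error_code': 0}).encode() + b'\0'

    return StreamingResponse(stream())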
gradio_web_server.py
CHANGED
@@ -90,81 +90,6 @@ def init_state(state=None):
     del state
     return Conversation()
 
-
-def find_bounding_boxes(state, response):
-    pattern = re.compile(r"<ref>\s*(.*?)\s*</ref>\s*<box>\s*(\[\[.*?\]\])\s*</box>")
-    matches = pattern.findall(response)
-    results = []
-    for match in matches:
-        results.append((match[0], eval(match[1])))
-    returned_image = None
-    latest_image = state.get_images(source=state.USER)[-1]
-    returned_image = latest_image.copy()
-    width, height = returned_image.size
-    draw = ImageDraw.Draw(returned_image)
-    for result in results:
-        line_width = max(1, int(min(width, height) / 200))
-        random_color = (
-            random.randint(0, 128),
-            random.randint(0, 128),
-            random.randint(0, 128),
-        )
-        category_name, coordinates = result
-        coordinates = [
-            (
-                float(x[0]) / 1000,
-                float(x[1]) / 1000,
-                float(x[2]) / 1000,
-                float(x[3]) / 1000,
-            )
-            for x in coordinates
-        ]
-        coordinates = [
-            (
-                int(x[0] * width),
-                int(x[1] * height),
-                int(x[2] * width),
-                int(x[3] * height),
-            )
-            for x in coordinates
-        ]
-        for box in coordinates:
-            draw.rectangle(box, outline=random_color, width=line_width)
-            font = ImageFont.truetype("assets/SimHei.ttf", int(20 * line_width / 2))
-            text_size = font.getbbox(category_name)
-            text_width, text_height = (
-                text_size[2] - text_size[0],
-                text_size[3] - text_size[1],
-            )
-            text_position = (box[0], max(0, box[1] - text_height))
-            draw.rectangle(
-                [
-                    text_position,
-                    (text_position[0] + text_width, text_position[1] + text_height),
-                ],
-                fill=random_color,
-            )
-            draw.text(text_position, category_name, fill="white", font=font)
-    return returned_image if len(matches) > 0 else None
-
-
-def query_image_generation(response, sd_worker_url, timeout=15):
-    if not sd_worker_url:
-        return None
-    sd_worker_url = f"{sd_worker_url}/generate_image/"
-    pattern = r"```drawing-instruction\n(.*?)\n```"
-    match = re.search(pattern, response, re.DOTALL)
-    if match:
-        payload = {"caption": match.group(1)}
-        print("drawing-instruction:", payload)
-        response = requests.post(sd_worker_url, json=payload, timeout=timeout)
-        response.raise_for_status()  # check that the HTTP request succeeded
-        image = Image.open(BytesIO(response.content))
-        return image
-    else:
-        return None
-
-
 def load_demo(url_params, request: gr.Request = None):
     if not request:
         logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
@@ -433,16 +358,6 @@ def http_bot(
         return
 
     ai_response = state.return_last_message()
-    if "<ref>" in ai_response:
-        returned_image = find_bounding_boxes(state, ai_response)
-        returned_image = [returned_image] if returned_image else []
-        state.update_message(Conversation.ASSISTANT, ai_response, returned_image)
-    if "```drawing-instruction" in ai_response:
-        returned_image = query_image_generation(
-            ai_response, sd_worker_url=sd_worker_url
-        )
-        returned_image = [returned_image] if returned_image else []
-        state.update_message(Conversation.ASSISTANT, ai_response, returned_image)
 
     state.end_of_current_turn()
 
@@ -823,7 +738,6 @@ if __name__ == "__main__":
     parser.add_argument(
        "--model-list-mode", type=str, default="reload", choices=["once", "reload"]
     )
-    parser.add_argument("--sd-worker-url", type=str, default=None)
     parser.add_argument("--share", action="store_true")
     parser.add_argument("--moderate", action="store_true")
     parser.add_argument("--embed", action="store_true")
@@ -837,7 +751,6 @@ if __name__ == "__main__":
 
     models = get_model_list()
 
-    sd_worker_url = args.sd_worker_url
     logger.info(args)
     demo = build_demo(args.embed)
     demo.queue(api_open=False).launch(
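One incidental note on the removed find_bounding_boxes: it ran eval() on model output to decode the <box> coordinate literal. If similar parsing is ever reinstated, a safer sketch of that step (helper name hypothetical, pattern copied from the removed code) would use ast.literal_eval:

import ast
import re

# Same pattern the removed helper used.
pattern = re.compile(r"<ref>\s*(.*?)\s*</ref>\s*<box>\s*(\[\[.*?\]\])\s*</box>")

def parse_boxes(response):
    # ast.literal_eval parses the [[x1, y1, x2, y2], ...] literal without
    # executing arbitrary code, unlike eval() on untrusted model output.
    return [(name, ast.literal_eval(coords))
            for name, coords in pattern.findall(response)]

print(parse_boxes('<ref>cat</ref><box>[[220, 130, 780, 940]]</box>'))
# -> [('cat', [[220, 130, 780, 940]])]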
sd_worker.py
DELETED
@@ -1,58 +0,0 @@
-# --------------------------------------------------------
-# InternVL
-# Copyright (c) 2024 OpenGVLab
-# Licensed under The MIT License [see LICENSE for details]
-# --------------------------------------------------------
-
-from io import BytesIO
-
-import torch
-from diffusers import StableDiffusion3Pipeline
-from fastapi import FastAPI
-from fastapi.responses import Response
-from pydantic import BaseModel
-
-# Initialize pipeline
-pipe = StableDiffusion3Pipeline.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers',
-                                                torch_dtype=torch.float16)
-pipe = pipe.to('cuda')
-
-# Create a FastAPI application
-app = FastAPI()
-
-
-# Define the input data model
-class CaptionRequest(BaseModel):
-    caption: str
-
-
-# Defining API endpoints
-@app.post('/generate_image/')
-async def generate_image(request: CaptionRequest):
-    caption = request.caption
-    negative_prompt = 'blurry, low resolution, artifacts, unnatural, poorly drawn, bad anatomy, out of focus'
-    image = pipe(
-        caption,
-        negative_prompt=negative_prompt,
-        num_inference_steps=20,
-        guidance_scale=7.0
-    ).images[0]
-
-    # Converts an image to a byte stream
-    img_byte_arr = BytesIO()
-    image.save(img_byte_arr, format='PNG')
-    img_byte_arr = img_byte_arr.getvalue()
-
-    return Response(content=img_byte_arr, media_type='image/png')
-
-
-# Run the Uvicorn server
-if __name__ == '__main__':
-    import argparse
-
-    import uvicorn
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--port', default=11005, type=int)
-    args = parser.parse_args()
-
-    uvicorn.run(app, host='0.0.0.0', port=args.port)
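For completeness, the deleted worker exposed a single POST endpoint that returned a PNG. The removed query_image_generation helper in gradio_web_server.py called it roughly like this minimal sketch (host assumed local; 11005 was the worker's default port):

from io import BytesIO

import requests
from PIL import Image

SD_WORKER_URL = 'http://localhost:11005'  # assumed host; default port from sd_worker.py

resp = requests.post(f'{SD_WORKER_URL}/generate_image/',
                     json={'caption': 'a watercolor fox in a snowy forest'},
                     timeout=15)
resp.raise_for_status()  # the worker replies with raw PNG bytes on success
Image.open(BytesIO(resp.content)).save('generated.png')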