baohuynhbk14 committed on
Commit
3c3eadf
·
1 Parent(s): b3f697c

Refactor API and controller files; remove unused code and update imports

Browse files
Files changed (4) hide show
  1. api.py +1 -48
  2. controller.py +1 -1
  3. gradio_web_server.py +0 -87
  4. sd_worker.py +0 -58
api.py CHANGED
@@ -30,51 +30,4 @@ def get_selected_worker_ip(controller_url, selected_model):
30
  def pil_image_to_base64(image):
31
  buffered = BytesIO()
32
  image.save(buffered, format='PNG')
33
- return base64.b64encode(buffered.getvalue()).decode('utf-8')
34
-
35
-
36
- controller_url = 'http://10.140.60.209:10075'
37
- model_list = get_model_list(controller_url)
38
- print(f'Model list: {model_list}')
39
-
40
- selected_model = 'InternVL2-1B'
41
- worker_addr = get_selected_worker_ip(controller_url, selected_model)
42
- print(f'model_name: {selected_model}, worker_addr: {worker_addr}')
43
-
44
-
45
- # 多轮/多图对话请把数据组织成以下格式:
46
- # send_messages = [{'role': 'system', 'content': system_message}]
47
- # send_messages.append({'role': 'user', 'content': 'question1 to image1', 'image': [pil_image_to_base64(image)]})
48
- # send_messages.append({'role': 'assistant', 'content': 'answer1'})
49
- # send_messages.append({'role': 'user', 'content': 'question2 to image2', 'image': [pil_image_to_base64(image)]})
50
- # send_messages.append({'role': 'assistant', 'content': 'answer2'})
51
- # send_messages.append({'role': 'user', 'content': 'question3 to image1 & 2', 'image': []})
52
-
53
- image = Image.open('image1.jpg')
54
- print(f'Loading image, size: {image.size}')
55
- system_message = """我是书生·万象,英文名是InternVL,是由上海人工智能实验室及多家合作单位联合开发的多模态大语言模型。人工智能实验室致力于原始技术创新,开源开放,共享共创,推动科技进步和产业发展。
56
- 请尽可能详细地回答用户的问题。"""
57
- send_messages = [{'role': 'system', 'content': system_message}]
58
- send_messages.append({'role': 'user', 'content': 'describe this image in detail', 'image': [pil_image_to_base64(image)]})
59
-
60
- pload = {
61
- 'model': selected_model,
62
- 'prompt': send_messages,
63
- 'temperature': 0.8,
64
- 'top_p': 0.7,
65
- 'max_new_tokens': 2048,
66
- 'max_input_tiles': 12,
67
- 'repetition_penalty': 1.0,
68
- }
69
- headers = {'User-Agent': 'InternVL-Chat Client'}
70
- response = requests.post(worker_addr + '/worker_generate_stream',
71
- headers=headers, json=pload, stream=True, timeout=10)
72
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b'\0'):
73
- if chunk:
74
- data = json.loads(chunk.decode())
75
- if data['error_code'] == 0:
76
- output = data['text'] # 这里是流式输出
77
- else:
78
- output = data['text'] + f" (error_code: {data['error_code']})"
79
- # 完整的输出
80
- print(output)
 
30
  def pil_image_to_base64(image):
31
  buffered = BytesIO()
32
  image.save(buffered, format='PNG')
33
+ return base64.b64encode(buffered.getvalue()).decode('utf-8')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
controller.py CHANGED
@@ -15,7 +15,7 @@ import numpy as np
15
  import requests
16
  import uvicorn
17
  from fastapi import FastAPI, Request
18
- from fastapi.responses import StreamingResponse
19
  from utils import build_logger, server_error_msg
20
 
21
  CONTROLLER_HEART_BEAT_EXPIRATION = 30
 
15
  import requests
16
  import uvicorn
17
  from fastapi import FastAPI, Request
18
+ from starlette.responses import StreamingResponse
19
  from utils import build_logger, server_error_msg
20
 
21
  CONTROLLER_HEART_BEAT_EXPIRATION = 30
gradio_web_server.py CHANGED
@@ -90,81 +90,6 @@ def init_state(state=None):
90
  del state
91
  return Conversation()
92
 
93
-
94
- def find_bounding_boxes(state, response):
95
- pattern = re.compile(r"<ref>\s*(.*?)\s*</ref>\s*<box>\s*(\[\[.*?\]\])\s*</box>")
96
- matches = pattern.findall(response)
97
- results = []
98
- for match in matches:
99
- results.append((match[0], eval(match[1])))
100
- returned_image = None
101
- latest_image = state.get_images(source=state.USER)[-1]
102
- returned_image = latest_image.copy()
103
- width, height = returned_image.size
104
- draw = ImageDraw.Draw(returned_image)
105
- for result in results:
106
- line_width = max(1, int(min(width, height) / 200))
107
- random_color = (
108
- random.randint(0, 128),
109
- random.randint(0, 128),
110
- random.randint(0, 128),
111
- )
112
- category_name, coordinates = result
113
- coordinates = [
114
- (
115
- float(x[0]) / 1000,
116
- float(x[1]) / 1000,
117
- float(x[2]) / 1000,
118
- float(x[3]) / 1000,
119
- )
120
- for x in coordinates
121
- ]
122
- coordinates = [
123
- (
124
- int(x[0] * width),
125
- int(x[1] * height),
126
- int(x[2] * width),
127
- int(x[3] * height),
128
- )
129
- for x in coordinates
130
- ]
131
- for box in coordinates:
132
- draw.rectangle(box, outline=random_color, width=line_width)
133
- font = ImageFont.truetype("assets/SimHei.ttf", int(20 * line_width / 2))
134
- text_size = font.getbbox(category_name)
135
- text_width, text_height = (
136
- text_size[2] - text_size[0],
137
- text_size[3] - text_size[1],
138
- )
139
- text_position = (box[0], max(0, box[1] - text_height))
140
- draw.rectangle(
141
- [
142
- text_position,
143
- (text_position[0] + text_width, text_position[1] + text_height),
144
- ],
145
- fill=random_color,
146
- )
147
- draw.text(text_position, category_name, fill="white", font=font)
148
- return returned_image if len(matches) > 0 else None
149
-
150
-
151
- def query_image_generation(response, sd_worker_url, timeout=15):
152
- if not sd_worker_url:
153
- return None
154
- sd_worker_url = f"{sd_worker_url}/generate_image/"
155
- pattern = r"```drawing-instruction\n(.*?)\n```"
156
- match = re.search(pattern, response, re.DOTALL)
157
- if match:
158
- payload = {"caption": match.group(1)}
159
- print("drawing-instruction:", payload)
160
- response = requests.post(sd_worker_url, json=payload, timeout=timeout)
161
- response.raise_for_status() # 检查HTTP请求是否成功
162
- image = Image.open(BytesIO(response.content))
163
- return image
164
- else:
165
- return None
166
-
167
-
168
  def load_demo(url_params, request: gr.Request = None):
169
  if not request:
170
  logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
@@ -433,16 +358,6 @@ def http_bot(
433
  return
434
 
435
  ai_response = state.return_last_message()
436
- if "<ref>" in ai_response:
437
- returned_image = find_bounding_boxes(state, ai_response)
438
- returned_image = [returned_image] if returned_image else []
439
- state.update_message(Conversation.ASSISTANT, ai_response, returned_image)
440
- if "```drawing-instruction" in ai_response:
441
- returned_image = query_image_generation(
442
- ai_response, sd_worker_url=sd_worker_url
443
- )
444
- returned_image = [returned_image] if returned_image else []
445
- state.update_message(Conversation.ASSISTANT, ai_response, returned_image)
446
 
447
  state.end_of_current_turn()
448
 
@@ -823,7 +738,6 @@ if __name__ == "__main__":
823
  parser.add_argument(
824
  "--model-list-mode", type=str, default="reload", choices=["once", "reload"]
825
  )
826
- parser.add_argument("--sd-worker-url", type=str, default=None)
827
  parser.add_argument("--share", action="store_true")
828
  parser.add_argument("--moderate", action="store_true")
829
  parser.add_argument("--embed", action="store_true")
@@ -837,7 +751,6 @@ if __name__ == "__main__":
837
 
838
  models = get_model_list()
839
 
840
- sd_worker_url = args.sd_worker_url
841
  logger.info(args)
842
  demo = build_demo(args.embed)
843
  demo.queue(api_open=False).launch(
 
90
  del state
91
  return Conversation()
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def load_demo(url_params, request: gr.Request = None):
94
  if not request:
95
  logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
 
358
  return
359
 
360
  ai_response = state.return_last_message()
 
 
 
 
 
 
 
 
 
 
361
 
362
  state.end_of_current_turn()
363
 
 
738
  parser.add_argument(
739
  "--model-list-mode", type=str, default="reload", choices=["once", "reload"]
740
  )
 
741
  parser.add_argument("--share", action="store_true")
742
  parser.add_argument("--moderate", action="store_true")
743
  parser.add_argument("--embed", action="store_true")
 
751
 
752
  models = get_model_list()
753
 
 
754
  logger.info(args)
755
  demo = build_demo(args.embed)
756
  demo.queue(api_open=False).launch(
sd_worker.py DELETED
@@ -1,58 +0,0 @@
1
- # --------------------------------------------------------
2
- # InternVL
3
- # Copyright (c) 2024 OpenGVLab
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # --------------------------------------------------------
6
-
7
- from io import BytesIO
8
-
9
- import torch
10
- from diffusers import StableDiffusion3Pipeline
11
- from fastapi import FastAPI
12
- from fastapi.responses import Response
13
- from pydantic import BaseModel
14
-
15
- # Initialize pipeline
16
- pipe = StableDiffusion3Pipeline.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers',
17
- torch_dtype=torch.float16)
18
- pipe = pipe.to('cuda')
19
-
20
- # Create a FastAPI application
21
- app = FastAPI()
22
-
23
-
24
- # Define the input data model
25
- class CaptionRequest(BaseModel):
26
- caption: str
27
-
28
-
29
- # Defining API endpoints
30
- @app.post('/generate_image/')
31
- async def generate_image(request: CaptionRequest):
32
- caption = request.caption
33
- negative_prompt = 'blurry, low resolution, artifacts, unnatural, poorly drawn, bad anatomy, out of focus'
34
- image = pipe(
35
- caption,
36
- negative_prompt=negative_prompt,
37
- num_inference_steps=20,
38
- guidance_scale=7.0
39
- ).images[0]
40
-
41
- # Converts an image to a byte stream
42
- img_byte_arr = BytesIO()
43
- image.save(img_byte_arr, format='PNG')
44
- img_byte_arr = img_byte_arr.getvalue()
45
-
46
- return Response(content=img_byte_arr, media_type='image/png')
47
-
48
-
49
- # Run the Uvicorn server
50
- if __name__ == '__main__':
51
- import argparse
52
-
53
- import uvicorn
54
- parser = argparse.ArgumentParser()
55
- parser.add_argument('--port', default=11005, type=int)
56
- args = parser.parse_args()
57
-
58
- uvicorn.run(app, host='0.0.0.0', port=args.port)