AnwenHu committed on
Commit beee9d0
1 Parent(s): 037fee8

Upload model_worker.py

Files changed (1)
model_worker.py +174 -0
model_worker.py ADDED
@@ -0,0 +1,174 @@
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
from functools import partial

from mplug_docowl.utils import (build_logger, server_error_msg,
                                pretty_print_semaphore)

from mplug_docowl.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, WORKER_HEART_BEAT_INTERVAL
from mplug_docowl.conversation import conv_templates, SeparatorStyle
from mplug_docowl.model.builder import load_pretrained_model
from mplug_docowl.mm_utils import load_image_from_base64, process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from mplug_docowl.processor import DocProcessor

from transformers import TextIteratorStreamer
from threading import Thread
from icecream import ic

GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0

model_semaphore = None


def heart_beat_worker(controller):
    # Periodically ping the controller so it knows this worker is still alive.
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()

class ModelWorker:
    def __init__(self,
                 model_path, model_base, model_name,
                 resolution, anchors, add_global_img,
                 load_8bit, load_4bit, device):

        if model_path.endswith("/"):
            model_path = model_path[:-1]

        self.model_name = get_model_name_from_path(model_path)

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")

        self.tokenizer, self.model, _, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)

        self.resolution = resolution
        # Visual tokens produced per image crop: a resolution x resolution crop is
        # split into 14x14 ViT patches, which are then compressed by the
        # vision-to-text module's conv_patch factor. Cast to int so the token
        # budgeting in generate_stream stays integer arithmetic.
        self.token_num_each_img = int((self.resolution / 14) * (self.resolution / 14) / self.model.get_model().vision2text.conv_patch)
        self.doc_image_processor = DocProcessor(image_size=resolution, anchors=anchors, add_global_img=add_global_img, add_textual_crop_indicator=True)

        self.is_multimodal = True

    @torch.inference_mode()
    def generate_stream(self, params):
        tokenizer, model = self.tokenizer, self.model

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        num_image_tokens = 0
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) > 0:
                images = [load_image_from_base64(image) for image in images]
                # DocOwl only supports 1 image, so only keep the last one.
                image = images[-1]
                assert prompt.count(DEFAULT_IMAGE_TOKEN) == 1

                # Crop the page image according to the anchor grid and rewrite the
                # query with per-crop indicators.
                images, patch_positions, prompt = self.doc_image_processor(images=image, query=prompt)
                images = images.to(self.model.device, dtype=torch.float16)
                patch_positions = patch_positions.to(self.model.device)

                replace_token = DEFAULT_IMAGE_TOKEN
                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
                num_image_tokens = prompt.count(replace_token) * (self.token_num_each_img + 1)
            else:
                images = None
                patch_positions = None
            image_args = {"images": images, "patch_positions": patch_positions}
        else:
            images = None
            image_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        # max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
        max_context_length = 4096
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        # do_sample = True if temperature > 0.001 else False
        do_sample = False

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

        # Leave room in the context window for the prompt and the image tokens.
        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
        ic(max_context_length, input_ids.shape[-1], num_image_tokens, max_new_tokens)

        if max_new_tokens < 1:
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode()
            return

        # Run generation in a background thread and stream partial text back.
        thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            # top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            stopping_criteria=[stopping_criteria],
            use_cache=True,
            **image_args
        ))
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            if generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            # yield json.dumps({"text": generated_text, "error_code": 0}).encode()
            # Replace < > with [ ] so that tags such as <doc>, <md>, <ocr>, <bbox>
            # are not stripped by the web frontend.
            yield json.dumps({"text": generated_text.replace('<', '[').replace('>', ']'), "error_code": 0}).encode()

    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode()
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode()
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode()
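
For orientation, here is a minimal, hypothetical smoke test for the worker defined above. It is not part of this commit: the checkpoint id, anchor preset, stop string, and prompt template are placeholder assumptions (prompts are presumably built from the imported conv_templates in the serving code), and running it requires a GPU plus the mplug_docowl package.

# Hypothetical usage sketch (not part of model_worker.py): instantiate the
# worker and consume the streamed JSON chunks from generate_stream_gate.
import base64

if __name__ == "__main__":
    worker = ModelWorker(
        model_path="mPLUG/DocOwl1.5-Chat",  # placeholder checkpoint id
        model_base=None,
        model_name=None,                    # derived from model_path inside __init__
        resolution=448,
        anchors="grid_9",                   # placeholder anchor preset for DocProcessor
        add_global_img=True,
        load_8bit=False,
        load_4bit=False,
        device="cuda",
    )

    with open("page.png", "rb") as f:       # placeholder document image
        image_b64 = base64.b64encode(f.read()).decode()

    params = {
        "prompt": f"USER: {DEFAULT_IMAGE_TOKEN} What is the title of this document? ASSISTANT:",
        "images": [image_b64],
        "temperature": 1.0,
        "max_new_tokens": 256,
        "stop": "</s>",                     # placeholder stop string
    }

    last = None
    for chunk in worker.generate_stream_gate(params):
        last = json.loads(chunk.decode())
    # Each streamed chunk carries the full text generated so far, so only the
    # final chunk needs to be printed.
    print(last["text"])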