Upload 52 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- mplug_docowl/__init__.py +2 -0
- mplug_docowl/__pycache__/__init__.cpython-310.pyc +0 -0
- mplug_docowl/__pycache__/constants.cpython-310.pyc +0 -0
- mplug_docowl/__pycache__/conversation.cpython-310.pyc +0 -0
- mplug_docowl/__pycache__/mm_utils.cpython-310.pyc +0 -0
- mplug_docowl/__pycache__/processor.cpython-310.pyc +0 -0
- mplug_docowl/__pycache__/utils.cpython-310.pyc +0 -0
- mplug_docowl/constants.py +9 -0
- mplug_docowl/conversation.py +301 -0
- mplug_docowl/local_serve/__init__.py +0 -0
- mplug_docowl/local_serve/examples/Rebecca_(1939_poster)_Small.jpeg +0 -0
- mplug_docowl/local_serve/examples/extreme_ironing.jpg +0 -0
- mplug_docowl/local_serve/local_web_server.py +392 -0
- mplug_docowl/local_serve/model_worker.py +143 -0
- mplug_docowl/mm_utils.py +112 -0
- mplug_docowl/model/__init__.py +2 -0
- mplug_docowl/model/__pycache__/__init__.cpython-310.pyc +0 -0
- mplug_docowl/model/__pycache__/builder.cpython-310.pyc +0 -0
- mplug_docowl/model/__pycache__/configuration_mplug_docowl.cpython-310.pyc +0 -0
- mplug_docowl/model/__pycache__/configuration_mplug_docowl2.cpython-310.pyc +0 -0
- mplug_docowl/model/__pycache__/convert_mplug_docowl2_weight_to_hf.cpython-310.pyc +0 -0
- mplug_docowl/model/__pycache__/convert_mplug_docowl_weight_to_hf.cpython-310.pyc +0 -0
- mplug_docowl/model/__pycache__/convert_mplug_docowl_weight_to_hf_v2.cpython-310.pyc +0 -0
- mplug_docowl/model/__pycache__/modeling_attn_mask_utils.cpython-310.pyc +0 -0
- mplug_docowl/model/__pycache__/modeling_llama2.cpython-310.pyc +0 -0
- mplug_docowl/model/__pycache__/modeling_mplug_docowl.cpython-310.pyc +0 -0
- mplug_docowl/model/__pycache__/modeling_mplug_docowl2.cpython-310.pyc +0 -0
- mplug_docowl/model/__pycache__/visual_encoder.cpython-310.pyc +0 -0
- mplug_docowl/model/builder.py +81 -0
- mplug_docowl/model/configuration_mplug_docowl.py +318 -0
- mplug_docowl/model/convert_mplug_docowl_weight_to_hf.py +319 -0
- mplug_docowl/model/convert_mplug_docowl_weight_to_hf_v2.py +320 -0
- mplug_docowl/model/modeling_attn_mask_utils.py +247 -0
- mplug_docowl/model/modeling_llama2.py +486 -0
- mplug_docowl/model/modeling_mplug_docowl.py +313 -0
- mplug_docowl/model/utils.py +20 -0
- mplug_docowl/model/visual_encoder.py +499 -0
- mplug_docowl/processor.py +219 -0
- mplug_docowl/serve/__init__.py +0 -0
- mplug_docowl/serve/cli.py +120 -0
- mplug_docowl/serve/controller.py +298 -0
- mplug_docowl/serve/examples/Rebecca_(1939_poster)_Small.jpeg +0 -0
- mplug_docowl/serve/examples/extreme_ironing.jpg +0 -0
- mplug_docowl/serve/gradio_web_server.py +460 -0
- mplug_docowl/serve/model_worker.py +342 -0
- mplug_docowl/serve/model_worker_bak.py +278 -0
- mplug_docowl/serve/register_workers.py +26 -0
- mplug_docowl/train/llama_flash_attn_monkey_patch.py +117 -0
- mplug_docowl/train/mplug_owl2_trainer.py +243 -0
- mplug_docowl/train/train.py +801 -0
mplug_docowl/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .model import MPLUGDocOwlLlamaForCausalLM
from .processor import DocProcessor
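For reference, a minimal usage sketch (not part of this commit) of the two exports above. The checkpoint path is a placeholder assumption; loading via `from_pretrained` mirrors the call used in mplug_docowl/model/builder.py.

# Sketch: the package re-exports the causal LM and the document processor.
from mplug_docowl import MPLUGDocOwlLlamaForCausalLM, DocProcessor

model = MPLUGDocOwlLlamaForCausalLM.from_pretrained("/path/to/docowl-checkpoint")  # placeholder path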
mplug_docowl/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (271 Bytes).
mplug_docowl/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (368 Bytes).
mplug_docowl/__pycache__/conversation.cpython-310.pyc
ADDED
Binary file (8.56 kB).
mplug_docowl/__pycache__/mm_utils.cpython-310.pyc
ADDED
Binary file (4.58 kB).
mplug_docowl/__pycache__/processor.cpython-310.pyc
ADDED
Binary file (6.68 kB).
mplug_docowl/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (4.05 kB).
mplug_docowl/constants.py
ADDED
@@ -0,0 +1,9 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "./demo_logs"

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<|image|>"
mplug_docowl/conversation.py
ADDED
@@ -0,0 +1,301 @@
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
from mplug_docowl.constants import DEFAULT_IMAGE_TOKEN

class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    TWO_NO_SYS = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            # init_msg = init_msg[0].replace("<image>", "").strip()
            # if 'mmtag' in self.version:
            #     messages[0] = (init_role, init_msg)
            #     messages.insert(0, (self.roles[0], "<Image><image></Image>"))
            #     messages.insert(1, (self.roles[1], "Received."))
            # else:
            #     messages[0] = (init_role, "<image>\n" + init_msg)
            init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip()
            messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO_NO_SYS:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0: message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<|image|>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
         "Renewable energy sources are those that can be replenished naturally in a relatively "
         "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
         "Non-renewable energy sources, on the other hand, are finite and will eventually be "
         "depleted, such as coal, oil, and natural gas. Here are some key differences between "
         "renewable and non-renewable energy sources:\n"
         "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
         "energy sources are finite and will eventually run out.\n"
         "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
         "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
         "and other negative effects.\n"
         "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
         "have lower operational costs than non-renewable sources.\n"
         "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
         "locations than non-renewable sources.\n"
         "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
         "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
         "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
         "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_mplug_owl2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO_NO_SYS,
    sep=" ",
    sep2="</s>",
)

# default_conversation = conv_vicuna_v1
default_conversation = conv_mplug_owl2
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "mplug_owl2": conv_mplug_owl2,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
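A short usage sketch (not part of this commit) showing how the templates above assemble a prompt; the question text is an arbitrary example.

# Sketch: build a single-turn prompt with the "mplug_owl2" template (TWO_NO_SYS style).
from mplug_docowl.conversation import conv_templates

conv = conv_templates["mplug_owl2"].copy()
conv.append_message(conv.roles[0], "<|image|>What is shown in this document?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
# prompt == "USER: <|image|>What is shown in this document? ASSISTANT:"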
mplug_docowl/local_serve/__init__.py
ADDED
File without changes
mplug_docowl/local_serve/examples/Rebecca_(1939_poster)_Small.jpeg
ADDED
mplug_docowl/local_serve/examples/extreme_ironing.jpg
ADDED
mplug_docowl/local_serve/local_web_server.py
ADDED
@@ -0,0 +1,392 @@
import argparse
import datetime
import json
import os
import time

import gradio as gr
import requests

from mplug_owl2.conversation import (default_conversation, conv_templates,
                                     SeparatorStyle)
from mplug_owl2.constants import LOGDIR
from mplug_owl2.utils import (build_logger, server_error_msg,
                              violates_moderation, moderation_msg)
from .model_worker import ModelWorker
import hashlib

logger = build_logger("gradio_web_server_local", "gradio_web_server_local.log")

headers = {"User-Agent": "mPLUG-Owl2 Client"}

no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)

def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name

get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
    }
"""


def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
    state = default_conversation.copy()
    return state


def vote_last_response(state, vote_type, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


def upvote_last_response(state, request: gr.Request):
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", request)
    return ("",) + (disable_btn,) * 3


def downvote_last_response(state, request: gr.Request):
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", request)
    return ("",) + (disable_btn,) * 3


def flag_last_response(state, request: gr.Request):
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", request)
    return ("",) + (disable_btn,) * 3


def regenerate(state, image_process_mode, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    prev_human_msg = state.messages[-2]
    if type(prev_human_msg[1]) in (tuple, list):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def add_text(state, text, image, image_process_mode, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0 and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
                no_change_btn,) * 5

    text = text[:3584]  # Hard cut-off
    if image is not None:
        text = text[:3500]  # Hard cut-off for images
        if '<|image|>' not in text:
            text = '<|image|>' + text
        text = (text, image, image_process_mode)
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def http_bot(state, temperature, top_p, max_new_tokens, request: gr.Request):
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        template_name = "mplug_owl2"
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Construct prompt
    prompt = state.get_prompt()

    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    for image, hash in zip(all_images, all_image_hash):
        t = datetime.datetime.now()
        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    # Make requests
    pload = {
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 2048),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
    }
    logger.info(f"==== request ====\n{pload}")

    pload['images'] = state.get_images()

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        # Stream output
        # response = requests.post(worker_addr + "/worker_generate_stream",
        #     headers=headers, json=pload, stream=True, timeout=10)
        # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        response = model.generate_stream_gate(pload)
        for chunk in response:
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][len(prompt):].strip()
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "start": round(start_tstamp, 4),
            "finish": round(start_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


title_markdown = ("""
<h1 align="center"><a href="https://github.com/X-PLUG/mPLUG-Owl"><img src="https://z1.ax1x.com/2023/11/03/piM1rGQ.md.png", alt="mPLUG-Owl" border="0" style="margin: 0 auto; height: 200px;" /></a> </h1>

<h2 align="center"> mPLUG-Owl2: Revolutionizing Multi-modal Large Language Model with Modality Collaboration</h2>

<h5 align="center"> If you like our project, please give us a star ✨ on Github for latest update. </h5>

<div align="center">
    <div style="display:flex; gap: 0.25rem;" align="center">
        <a href='https://github.com/X-PLUG/mPLUG-Owl'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
        <a href="https://arxiv.org/abs/2304.14178"><img src="https://img.shields.io/badge/Arxiv-2304.14178-red"></a>
        <a href='https://github.com/X-PLUG/mPLUG-Owl/stargazers'><img src='https://img.shields.io/github/stars/X-PLUG/mPLUG-Owl.svg?style=social'></a>
    </div>
</div>

""")


tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")


learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")

block_css = """

#buttons button {
    min-width: min(120px,100%);
}

"""

def build_demo(embed_mode):
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="mPLUG-Owl2", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
                    [f"{cur_dir}/examples/Rebecca_(1939_poster)_Small.jpeg", "What is the name of the movie in the poster?"],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=True) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(elem_id="Chatbot", label="mPLUG-Owl2 Chatbot", height=600)
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            state,
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        downvote_btn.click(
            downvote_last_response,
            state,
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        flag_btn.click(
            flag_last_response,
            state,
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )

        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        demo.load(
            load_demo,
            [url_params],
            state,
            _js=get_window_url_params,
            queue=False
        )

    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--model-list-mode", type=str, default="once",
        choices=["once", "reload"])
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    model = ModelWorker(args.model_path, None, None, args.load_8bit, args.load_4bit, args.device)

    logger.info(args)
    demo = build_demo(args.embed)
    demo.queue(
        concurrency_count=args.concurrency_count,
        api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=False
    )
mplug_docowl/local_serve/model_worker.py
ADDED
@@ -0,0 +1,143 @@
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid

import requests
import torch
from functools import partial

from mplug_owl2.constants import WORKER_HEART_BEAT_INTERVAL
from mplug_owl2.utils import (build_logger, server_error_msg,
                              pretty_print_semaphore)
from mplug_owl2.model.builder import load_pretrained_model
from mplug_owl2.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from transformers import TextIteratorStreamer
from threading import Thread

GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")

class ModelWorker:
    def __init__(self, model_path, model_base, model_name, load_8bit, load_4bit, device):
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
        self.is_multimodal = True

    @torch.inference_mode()
    def generate_stream(self, params):
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        num_image_tokens = 0
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) > 0:
                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                    raise ValueError("Number of images does not match number of <|image|> tokens in prompt")

                images = [load_image_from_base64(image) for image in images]
                images = process_images(images, image_processor, model.config)

                if type(images) is list:
                    images = [image.to(self.model.device, dtype=torch.float16) for image in images]
                else:
                    images = images.to(self.model.device, dtype=torch.float16)

                replace_token = DEFAULT_IMAGE_TOKEN
                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

                num_image_tokens = prompt.count(replace_token) * (model.get_model().visual_abstractor.config.num_learnable_queries + 1)
            else:
                images = None
            image_args = {"images": images}
        else:
            images = None
            image_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = True if temperature > 0.001 else False

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        if max_new_tokens < 1:
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
            return

        thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            stopping_criteria=[stopping_criteria],
            use_cache=True,
            **image_args
        ))
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            if generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode()

    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode()
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode()
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode()
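A hedged driving sketch (not part of this commit): constructing the worker and consuming its stream directly, mirroring how local_web_server.py calls generate_stream_gate. The checkpoint path and the base64 image string are placeholders.

import json
from mplug_docowl.local_serve.model_worker import ModelWorker

worker = ModelWorker("/path/to/docowl-checkpoint", None, None,   # placeholder path
                     load_8bit=False, load_4bit=False, device="cuda")
params = {
    "prompt": "USER: <|image|>What is the title of this document? ASSISTANT:",
    "images": ["<base64-encoded image>"],   # one base64 string per <|image|> token
    "temperature": 0.2,
    "top_p": 0.7,
    "max_new_tokens": 256,
    "stop": "</s>",
}
for chunk in worker.generate_stream_gate(params):   # yields JSON-encoded byte chunks
    data = json.loads(chunk.decode())
    print(data["text"], data["error_code"])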
mplug_docowl/mm_utils.py
ADDED
@@ -0,0 +1,112 @@
from PIL import Image
from io import BytesIO
import base64

import torch
from transformers import StoppingCriteria
from mplug_docowl.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from icecream import ic


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def process_images(images, image_processor, model_cfg=None):
    if model_cfg is not None:
        image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    else:
        image_aspect_ratio = 'resize'
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    elif image_aspect_ratio == 'resize':
        for image in images:
            max_edge = max(image.size)
            image = image.resize((max_edge, max_edge))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False
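A usage sketch (not part of this commit) for tokenizer_image_token: it tokenizes the text around each "<|image|>" marker and splices in the sentinel id IMAGE_TOKEN_INDEX (-200). The tokenizer path is a placeholder; any LLaMA-style slow tokenizer with a BOS token is assumed.

from transformers import AutoTokenizer
from mplug_docowl.mm_utils import tokenizer_image_token
from mplug_docowl.constants import IMAGE_TOKEN_INDEX

tokenizer = AutoTokenizer.from_pretrained("/path/to/docowl-checkpoint", use_fast=False)  # placeholder path
prompt = "USER: <|image|>Summarize this page. ASSISTANT:"
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
# input_ids is a 1-D LongTensor; the single -200 entry marks where visual features are spliced in.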
mplug_docowl/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .modeling_mplug_docowl import MPLUGDocOwlLlamaForCausalLM
from .configuration_mplug_docowl import MPLUGDocOwlConfig
mplug_docowl/model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (315 Bytes).
mplug_docowl/model/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (1.64 kB).
mplug_docowl/model/__pycache__/configuration_mplug_docowl.cpython-310.pyc
ADDED
Binary file (12.9 kB).
mplug_docowl/model/__pycache__/configuration_mplug_docowl2.cpython-310.pyc
ADDED
Binary file (14 kB).
mplug_docowl/model/__pycache__/convert_mplug_docowl2_weight_to_hf.cpython-310.pyc
ADDED
Binary file (9.28 kB).
mplug_docowl/model/__pycache__/convert_mplug_docowl_weight_to_hf.cpython-310.pyc
ADDED
Binary file (9.12 kB).
mplug_docowl/model/__pycache__/convert_mplug_docowl_weight_to_hf_v2.cpython-310.pyc
ADDED
Binary file (9.07 kB).
mplug_docowl/model/__pycache__/modeling_attn_mask_utils.cpython-310.pyc
ADDED
Binary file (7.6 kB).
mplug_docowl/model/__pycache__/modeling_llama2.cpython-310.pyc
ADDED
Binary file (13.3 kB).
mplug_docowl/model/__pycache__/modeling_mplug_docowl.cpython-310.pyc
ADDED
Binary file (9.28 kB).
mplug_docowl/model/__pycache__/modeling_mplug_docowl2.cpython-310.pyc
ADDED
Binary file (10 kB).
mplug_docowl/model/__pycache__/visual_encoder.cpython-310.pyc
ADDED
Binary file (15.1 kB).
mplug_docowl/model/builder.py
ADDED
@@ -0,0 +1,81 @@
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import warnings
import shutil

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
from transformers.models.clip.image_processing_clip import CLIPImageProcessor
import torch
from mplug_docowl.model import *
from icecream import ic

def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
    kwargs = {"device_map": device_map}

    if device != "cuda":
        kwargs['device_map'] = {"": device}

    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16
    if 'paperowl' in model_name.lower() or 'docowl' in model_name.lower():
        if model_base is not None:
            # this may be mm projector only
            print('Loading mPLUG-DocOwl from base model...')
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            cfg_pretrained = AutoConfig.from_pretrained(model_path)
            model = MPLUGDocOwlLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
            model = MPLUGDocOwlLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print(f"Merging weights")
            model = model.merge_and_unload()
            print('Convert to FP16...')
            model.to(torch.float16)
        else:
            use_fast = False
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)


    # vision_tower = model.get_model().vision_model
    # vision_tower.to(device=device, dtype=torch.float16)
    image_processor = CLIPImageProcessor.from_pretrained(model_path)

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
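A loading sketch (not part of this commit) that ties the helpers above together; the checkpoint path is a placeholder assumption.

from mplug_docowl.model.builder import load_pretrained_model
from mplug_docowl.mm_utils import get_model_name_from_path

model_path = "/path/to/docowl-checkpoint"  # placeholder
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name=get_model_name_from_path(model_path),
    load_8bit=False, load_4bit=False, device="cuda")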
mplug_docowl/model/configuration_mplug_docowl.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Alibaba.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
import copy
|
6 |
+
import os
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
from transformers.configuration_utils import PretrainedConfig
|
10 |
+
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
11 |
+
from transformers.utils import logging
|
12 |
+
from transformers.models.auto import CONFIG_MAPPING
|
13 |
+
|
14 |
+
|
15 |
+
class LlamaConfig(PretrainedConfig):
|
16 |
+
r"""
|
17 |
+
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
|
18 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
19 |
+
defaults will yield a similar configuration to that of the LLaMA-7B.
|
20 |
+
|
21 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
22 |
+
documentation from [`PretrainedConfig`] for more information.
|
23 |
+
|
24 |
+
|
25 |
+
Args:
|
26 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
27 |
+
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
|
28 |
+
`inputs_ids` passed when calling [`LlamaModel`]
|
29 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
30 |
+
Dimension of the hidden representations.
|
31 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
32 |
+
Dimension of the MLP representations.
|
33 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
34 |
+
Number of hidden layers in the Transformer decoder.
|
35 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
36 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
37 |
+
num_key_value_heads (`int`, *optional*):
|
38 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
39 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
40 |
+
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
41 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
42 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
43 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
44 |
+
`num_attention_heads`.
|
45 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
46 |
+
The non-linear activation function (function or string) in the decoder.
|
47 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
48 |
+
The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
|
49 |
+
Llama 2 up to 4096, CodeLlama up to 16384.
|
50 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
51 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
52 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
53 |
+
The epsilon used by the rms normalization layers.
|
54 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
55 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
56 |
+
relevant if `config.is_decoder=True`.
|
57 |
+
pad_token_id (`int`, *optional*):
|
58 |
+
Padding token id.
|
59 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
60 |
+
Beginning of stream token id.
|
61 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
62 |
+
End of stream token id.
|
63 |
+
pretraining_tp (`int`, *optional*, defaults to 1):
|
64 |
+
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
65 |
+
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
|
66 |
+
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
|
67 |
+
issue](https://github.com/pytorch/pytorch/issues/76232).
|
68 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
69 |
+
Whether to tie weight embeddings
|
70 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
71 |
+
The base period of the RoPE embeddings.
|
72 |
+
rope_scaling (`Dict`, *optional*):
|
73 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
74 |
+
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
75 |
+
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
76 |
+
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
77 |
+
these scaling strategies behave:
|
78 |
+
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
79 |
+
experimental feature, subject to breaking API changes in future versions.
|
80 |
+
attention_bias (`bool`, *optional*, defaults to `False`):
|
81 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
82 |
+
|
83 |
+
|
84 |
+
```python
|
85 |
+
>>> from transformers import LlamaModel, LlamaConfig
|
86 |
+
|
87 |
+
>>> # Initializing a LLaMA llama-7b style configuration
|
88 |
+
>>> configuration = LlamaConfig()
|
89 |
+
|
90 |
+
>>> # Initializing a model from the llama-7b style configuration
|
91 |
+
>>> model = LlamaModel(configuration)
|
92 |
+
|
93 |
+
>>> # Accessing the model configuration
|
94 |
+
>>> configuration = model.config
|
95 |
+
```"""
|
96 |
+
model_type = "llama"
|
97 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
98 |
+
|
99 |
+
def __init__(
|
100 |
+
self,
|
101 |
+
vocab_size=32000,
|
102 |
+
hidden_size=4096,
|
103 |
+
intermediate_size=11008,
|
104 |
+
num_hidden_layers=32,
|
105 |
+
num_attention_heads=32,
|
106 |
+
num_key_value_heads=None,
|
107 |
+
hidden_act="silu",
|
108 |
+
max_position_embeddings=2048,
|
109 |
+
initializer_range=0.02,
|
110 |
+
rms_norm_eps=1e-6,
|
111 |
+
use_cache=True,
|
112 |
+
pad_token_id=None,
|
113 |
+
bos_token_id=1,
|
114 |
+
eos_token_id=2,
|
115 |
+
pretraining_tp=1,
|
116 |
+
tie_word_embeddings=False,
|
117 |
+
rope_theta=10000.0,
|
118 |
+
rope_scaling=None,
|
119 |
+
attention_bias=False,
|
120 |
+
**kwargs,
|
121 |
+
):
|
122 |
+
self.vocab_size = vocab_size
|
123 |
+
self.max_position_embeddings = max_position_embeddings
|
124 |
+
self.hidden_size = hidden_size
|
125 |
+
self.intermediate_size = intermediate_size
|
126 |
+
self.num_hidden_layers = num_hidden_layers
|
127 |
+
self.num_attention_heads = num_attention_heads
|
128 |
+
|
129 |
+
# for backward compatibility
|
130 |
+
if num_key_value_heads is None:
|
131 |
+
num_key_value_heads = num_attention_heads
|
132 |
+
|
133 |
+
self.num_key_value_heads = num_key_value_heads
|
134 |
+
self.hidden_act = hidden_act
|
135 |
+
self.initializer_range = initializer_range
|
136 |
+
self.rms_norm_eps = rms_norm_eps
|
137 |
+
self.pretraining_tp = pretraining_tp
|
138 |
+
self.use_cache = use_cache
|
139 |
+
self.rope_theta = rope_theta
|
140 |
+
self.rope_scaling = rope_scaling
|
141 |
+
self._rope_scaling_validation()
|
142 |
+
self.attention_bias = attention_bias
|
143 |
+
|
144 |
+
super().__init__(
|
145 |
+
pad_token_id=pad_token_id,
|
146 |
+
bos_token_id=bos_token_id,
|
147 |
+
eos_token_id=eos_token_id,
|
148 |
+
tie_word_embeddings=tie_word_embeddings,
|
149 |
+
**kwargs,
|
150 |
+
)
|
151 |
+
|
152 |
+
def _rope_scaling_validation(self):
|
153 |
+
"""
|
154 |
+
Validate the `rope_scaling` configuration.
|
155 |
+
"""
|
156 |
+
if self.rope_scaling is None:
|
157 |
+
return
|
158 |
+
|
159 |
+
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
160 |
+
raise ValueError(
|
161 |
+
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
|
162 |
+
f"got {self.rope_scaling}"
|
163 |
+
)
|
164 |
+
rope_scaling_type = self.rope_scaling.get("type", None)
|
165 |
+
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
166 |
+
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
167 |
+
raise ValueError(
|
168 |
+
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
169 |
+
)
|
170 |
+
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
171 |
+
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
172 |
+
|
173 |
+
|
174 |
+
class MplugOwlVisionConfig(PretrainedConfig):
|
175 |
+
r"""
|
176 |
+
This is the configuration class to store the configuration of a [`MplugOwlVisionModel`]. It is used to instantiate
|
177 |
+
a
|
178 |
+
mPLUG-Owl vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
179 |
+
configuration with the defaults will yield a similar configuration to that of the mPLUG-Owl
|
180 |
+
[x-plug/x_plug-llama-7b](https://huggingface.co/x-plug/x_plug-llama-7b) architecture.
|
181 |
+
|
182 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
183 |
+
documentation from [`PretrainedConfig`] for more information.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
hidden_size (`int`, *optional*, defaults to 1024):
|
187 |
+
Dimensionality of the encoder layers and the pooler layer.
|
188 |
+
intermediate_size (`int`, *optional*, defaults to 4096):
|
189 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
190 |
+
num_hidden_layers (`int`, *optional*, defaults to 24):
|
191 |
+
Number of hidden layers in the Transformer encoder.
|
192 |
+
num_attention_heads (`int`, *optional*, defaults to 16):
|
193 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
194 |
+
image_size (`int`, *optional*, defaults to 448):
|
195 |
+
The size (resolution) of each image.
|
196 |
+
patch_size (`int`, *optional*, defaults to 14):
|
197 |
+
The size (resolution) of each patch.
|
198 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
|
199 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
200 |
+
`"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
|
201 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
202 |
+
The epsilon used by the layer normalization layers.
|
203 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
204 |
+
The dropout ratio for the attention probabilities.
|
205 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
206 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
207 |
+
initializer_factor (`float`, *optional*, defaults to 1):
|
208 |
+
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
209 |
+
testing).
|
210 |
+
|
211 |
+
|
212 |
+
```"""
|
213 |
+
|
214 |
+
model_type = "mplug_owl_vision_model"
|
215 |
+
|
216 |
+
def __init__(
|
217 |
+
self,
|
218 |
+
hidden_size=1024,
|
219 |
+
intermediate_size=4096,
|
220 |
+
projection_dim=768,
|
221 |
+
num_hidden_layers=24,
|
222 |
+
num_attention_heads=16,
|
223 |
+
num_channels=3,
|
224 |
+
image_size=448,
|
225 |
+
patch_size=14,
|
226 |
+
hidden_act="quick_gelu",
|
227 |
+
layer_norm_eps=1e-6,
|
228 |
+
attention_dropout=0.0,
|
229 |
+
initializer_range=0.02,
|
230 |
+
initializer_factor=1.0,
|
231 |
+
use_flash_attn=False,
|
232 |
+
**kwargs,
|
233 |
+
):
|
234 |
+
super().__init__(**kwargs)
|
235 |
+
self.hidden_size = hidden_size
|
236 |
+
self.intermediate_size = intermediate_size
|
237 |
+
self.projection_dim = projection_dim
|
238 |
+
self.num_hidden_layers = num_hidden_layers
|
239 |
+
self.num_attention_heads = num_attention_heads
|
240 |
+
self.num_channels = num_channels
|
241 |
+
self.patch_size = patch_size
|
242 |
+
self.image_size = image_size
|
243 |
+
self.initializer_range = initializer_range
|
244 |
+
self.initializer_factor = initializer_factor
|
245 |
+
self.attention_dropout = attention_dropout
|
246 |
+
self.layer_norm_eps = layer_norm_eps
|
247 |
+
self.hidden_act = hidden_act
|
248 |
+
self.use_flash_attn = use_flash_attn
|
249 |
+
|
250 |
+
@classmethod
|
251 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
252 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
253 |
+
|
254 |
+
# get the vision config dict if we are loading from MplugOwlConfig
|
255 |
+
if config_dict.get("model_type") == "mplug-owl":
|
256 |
+
config_dict = config_dict["vision_config"]
|
257 |
+
|
258 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
259 |
+
logger.warning(
|
260 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
261 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
262 |
+
)
|
263 |
+
|
264 |
+
return cls.from_dict(config_dict, **kwargs)
|
265 |
+
|
266 |
+
|
267 |
+
class MplugDocOwlHReducerConfig(PretrainedConfig):
|
268 |
+
model_type = "mplug_docowl_hreducer"
|
269 |
+
|
270 |
+
def __init__(
|
271 |
+
self,
|
272 |
+
hidden_size=1024,
|
273 |
+
initializer_range=0.02,
|
274 |
+
layer_norm_eps=1e-6,
|
275 |
+
conv_shape='1x4',
|
276 |
+
**kwargs,
|
277 |
+
):
|
278 |
+
super().__init__(**kwargs)
|
279 |
+
self.hidden_size = hidden_size
|
280 |
+
self.initializer_range = initializer_range
|
281 |
+
self.layer_norm_eps = layer_norm_eps
|
282 |
+
self.conv_shape = conv_shape
|
283 |
+
|
284 |
+
@classmethod
|
285 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
286 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
287 |
+
|
288 |
+
# get the visual_abstractor config dict if we are loading from MplugOwlConfig
|
289 |
+
if config_dict.get("model_type") == "mplug-docowl":
|
290 |
+
config_dict = config_dict["hreducer_config"]
|
291 |
+
|
292 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
293 |
+
logger.warning(
|
294 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
295 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
296 |
+
)
|
297 |
+
|
298 |
+
return cls.from_dict(config_dict, **kwargs)
|
299 |
+
|
300 |
+
DEFAULT_VISUAL_CONFIG = {
|
301 |
+
"visual_model": MplugOwlVisionConfig().to_dict(),
|
302 |
+
"visual_hreducer": MplugDocOwlHReducerConfig().to_dict()
|
303 |
+
}
|
304 |
+
|
305 |
+
class MPLUGDocOwlConfig(LlamaConfig):
|
306 |
+
model_type = "mplug_docowl"
|
307 |
+
def __init__(self, visual_config=None, **kwargs):
|
308 |
+
if visual_config is None:
|
309 |
+
self.visual_config = DEFAULT_VISUAL_CONFIG
|
310 |
+
else:
|
311 |
+
self.visual_config = visual_config
|
312 |
+
|
313 |
+
super().__init__(
|
314 |
+
**kwargs,
|
315 |
+
)
|
316 |
+
|
317 |
+
if __name__ == "__main__":
|
318 |
+
print(MplugOwlVisionConfig().to_dict())
|
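The configuration module above only defines the classes; as a quick orientation, here is a minimal sketch (not part of the uploaded files) of how they compose, assuming the package is importable as `mplug_docowl` and using a hypothetical output directory.

```python
# Editorial sketch, not part of the repository. Assumes `mplug_docowl` is on the
# Python path; the output directory below is a hypothetical placeholder.
from mplug_docowl.model.configuration_mplug_docowl import (
    MPLUGDocOwlConfig,
    MplugDocOwlHReducerConfig,
    MplugOwlVisionConfig,
)

# Mirror DEFAULT_VISUAL_CONFIG: a ViT-L/14 style encoder plus the 1x4 H-Reducer.
visual_config = {
    "visual_model": MplugOwlVisionConfig(image_size=448).to_dict(),
    "visual_hreducer": MplugDocOwlHReducerConfig(conv_shape="1x4").to_dict(),
}

config = MPLUGDocOwlConfig(visual_config=visual_config)
print(config.model_type)                   # "mplug_docowl"
config.save_pretrained("./docowl_config")  # hypothetical path; writes config.json
```

When `visual_config` is omitted, `DEFAULT_VISUAL_CONFIG` defined in the file above is used.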
mplug_docowl/model/convert_mplug_docowl_weight_to_hf.py
ADDED
@@ -0,0 +1,319 @@
1 |
+
# Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import argparse
|
15 |
+
import gc
|
16 |
+
import json
|
17 |
+
import math
|
18 |
+
import os
|
19 |
+
import shutil
|
20 |
+
import warnings
|
21 |
+
|
22 |
+
import torch
|
23 |
+
|
24 |
+
from transformers import LlamaTokenizer
|
25 |
+
from .configuration_mplug_docowl import MPLUGDocOwlConfig
|
26 |
+
from icecream import ic
|
27 |
+
|
28 |
+
try:
|
29 |
+
from transformers import LlamaTokenizerFast
|
30 |
+
except ImportError as e:
|
31 |
+
warnings.warn(e)
|
32 |
+
warnings.warn(
|
33 |
+
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
|
34 |
+
)
|
35 |
+
LlamaTokenizerFast = None
|
36 |
+
|
37 |
+
"""
|
38 |
+
Sample usage:
|
39 |
+
|
40 |
+
```
|
41 |
+
python3 /pure-mlo-scratch/sfan/model-parallel-trainer/llama2megatron/convert_llama2hf.py \
|
42 |
+
--input_dir /pure-mlo-scratch/llama/ --model_size 7 --output_dir /pure-mlo-scratch/llama/converted_HF_7B
|
43 |
+
```
|
44 |
+
|
45 |
+
Thereafter, models can be loaded via:
|
46 |
+
|
47 |
+
```py
|
48 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
49 |
+
|
50 |
+
model = LlamaForCausalLM.from_pretrained("/output/path")
|
51 |
+
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
|
52 |
+
```
|
53 |
+
|
54 |
+
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
|
55 |
+
come in several checkpoints, each containing a part of each weight of the model, so we need to load them all in RAM).
|
56 |
+
"""
|
57 |
+
|
58 |
+
llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
|
59 |
+
llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
|
60 |
+
llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
|
61 |
+
70: 28672} # should be (2/3)*4*d, but it isn't exactly that
|
62 |
+
llama_s2hidden = {7: 4096, 13: 5120, 30: 6656, 65: 8192, 70: 8192}
|
63 |
+
|
64 |
+
|
65 |
+
def compute_intermediate_size(n):
|
66 |
+
return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
|
67 |
+
|
68 |
+
|
69 |
+
def read_json(path):
|
70 |
+
with open(path, "r") as f:
|
71 |
+
return json.load(f)
|
72 |
+
|
73 |
+
|
74 |
+
def write_json(text, path):
|
75 |
+
with open(path, "w") as f:
|
76 |
+
json.dump(text, f)
|
77 |
+
|
78 |
+
|
79 |
+
def write_model(model_path,
|
80 |
+
input_base_path,
|
81 |
+
model_size,
|
82 |
+
num_input_shards=1,
|
83 |
+
num_output_shards=2,
|
84 |
+
skip_permute=True,
|
85 |
+
norm_eps=1e-05):
|
86 |
+
# if os.path.exists(model_path):
|
87 |
+
# shutil.rmtree(model_path)
|
88 |
+
os.makedirs(model_path, exist_ok=True)
|
89 |
+
# tmp_model_path = os.path.join(model_path, "tmp")
|
90 |
+
tmp_model_path = model_path
|
91 |
+
os.makedirs(tmp_model_path, exist_ok=True)
|
92 |
+
|
93 |
+
num_shards = num_input_shards
|
94 |
+
n_layers = llama_s2layer[model_size]
|
95 |
+
n_heads = llama_s2heads[model_size]
|
96 |
+
n_heads_per_shard = n_heads // num_shards
|
97 |
+
n_dense = llama_s2dense[model_size]
|
98 |
+
n_hidden = llama_s2hidden[model_size]
|
99 |
+
hidden_per_head = n_hidden // n_heads
|
100 |
+
base = 10000.0
|
101 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))
|
102 |
+
|
103 |
+
# permute for sliced rotary
|
104 |
+
def permute(w, skip_permute=skip_permute):
|
105 |
+
if skip_permute:
|
106 |
+
return w
|
107 |
+
return w.view(n_heads, n_hidden // n_heads // 2, 2, n_hidden).transpose(1, 2).reshape(n_hidden, n_hidden)
|
108 |
+
|
109 |
+
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
|
110 |
+
# Load weights
|
111 |
+
if num_shards==1:
|
112 |
+
# Not sharded
|
113 |
+
# (The sharded implementation would also work, but this is simpler.)
|
114 |
+
# /pure-mlo-scratch/alhernan/megatron-data/checkpoints/llama2-7b-tp4-pp1-optim/release/mp_rank_00/model_optim_rng.pt
|
115 |
+
if os.path.exists(os.path.join(input_base_path, 'release')):
|
116 |
+
filename = os.path.join(input_base_path, 'release', 'mp_rank_00', 'model_optim_rng.pt')
|
117 |
+
elif input_base_path.split('/')[-1].startswith('iter_'):
|
118 |
+
iteration = eval(input_base_path.split('/')[-1].replace('iter_', '').lstrip('0'))
|
119 |
+
load_dir = '/'.join(input_base_path.split('/')[:-1])
|
120 |
+
filename = os.path.join(input_base_path, 'mp_rank_00', 'model_optim_rng.pt')
|
121 |
+
if not os.path.exists(filename):
|
122 |
+
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
|
123 |
+
else:
|
124 |
+
tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
|
125 |
+
with open(tracker_filename, 'r') as f:
|
126 |
+
metastring = f.read().strip()
|
127 |
+
iteration = 'iter_{:07d}'.format(int(metastring))
|
128 |
+
filename = os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt')
|
129 |
+
if not os.path.exists(filename):
|
130 |
+
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
|
131 |
+
original_filename = filename
|
132 |
+
loaded = torch.load(filename, map_location="cpu")['model']['language_model']
|
133 |
+
|
134 |
+
else:
|
135 |
+
# Sharded
|
136 |
+
filenames = []
|
137 |
+
for i in range(num_shards):
|
138 |
+
if os.path.exists(os.path.join(input_base_path, 'release')):
|
139 |
+
filename = os.path.join(input_base_path, 'release', f'mp_rank_{i:02d}', 'model_optim_rng.pt')
|
140 |
+
else:
|
141 |
+
tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
|
142 |
+
with open(tracker_filename, 'r') as f:
|
143 |
+
metastring = f.read().strip()
|
144 |
+
iteration = 'iter_{:07d}'.format(int(metastring))
|
145 |
+
filename = os.path.join(input_base_path, iteration, f'mp_rank_{i:02d}', 'model_optim_rng.pt')
|
146 |
+
if not os.path.exists(filename):
|
147 |
+
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
|
148 |
+
filenames.append(filename)
|
149 |
+
loaded = [
|
150 |
+
torch.load(filenames[i], map_location="cpu")['model']['language_model']
|
151 |
+
for i in range(num_shards)
|
152 |
+
]
|
153 |
+
|
154 |
+
print('Llama-Megatron Loaded!')
|
155 |
+
param_count = 0
|
156 |
+
index_dict = {"weight_map": {}}
|
157 |
+
|
158 |
+
print(f'Weighted Converting for {n_layers} layers...')
|
159 |
+
for layer_i in range(n_layers):
|
160 |
+
print(layer_i)
|
161 |
+
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
|
162 |
+
if num_shards == 1:
|
163 |
+
# Unsharded
|
164 |
+
state_dict = {
|
165 |
+
f"model.layers.{layer_i}.self_attn.q_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"],
|
166 |
+
f"model.layers.{layer_i}.self_attn.k_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"],
|
167 |
+
f"model.layers.{layer_i}.self_attn.v_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"],
|
168 |
+
f"model.layers.{layer_i}.self_attn.k_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.1.weight"],
|
169 |
+
f"model.layers.{layer_i}.self_attn.v_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.1.weight"],
|
170 |
+
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"],
|
171 |
+
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"],
|
172 |
+
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"],
|
173 |
+
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"],
|
174 |
+
f"model.layers.{layer_i}.input_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.0.weight"],
|
175 |
+
f"model.layers.{layer_i}.post_attention_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"],
|
176 |
+
f"model.layers.{layer_i}.input_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.1.weight"],
|
177 |
+
f"model.layers.{layer_i}.post_attention_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.1.weight"],
|
178 |
+
}
|
179 |
+
else:
|
180 |
+
raise NotImplementedError
|
181 |
+
|
182 |
+
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
|
183 |
+
for k, v in state_dict.items():
|
184 |
+
index_dict["weight_map"][k] = filename
|
185 |
+
param_count += v.numel()
|
186 |
+
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
187 |
+
print(f'Sharded file saved to {filename}')
|
188 |
+
|
189 |
+
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
|
190 |
+
if num_shards==1:
|
191 |
+
# Unsharded
|
192 |
+
state_dict = {
|
193 |
+
"model.embed_tokens.weight": loaded['embedding']['word_embeddings']['weight'],
|
194 |
+
"model.norm.weight": loaded['encoder']['norm.weight'],
|
195 |
+
"lm_head.weight": loaded['encoder']['lm_head.weight'],
|
196 |
+
}
|
197 |
+
else:
|
198 |
+
state_dict = {
|
199 |
+
"model.embed_tokens.weight": loaded[0]['embedding']['word_embeddings']['weight'],
|
200 |
+
"model.norm.weight": loaded[0]['encoder']['norm.weight'],
|
201 |
+
"lm_head.weight": loaded[0]['encoder']['lm_head.weight'],
|
202 |
+
}
|
203 |
+
|
204 |
+
|
205 |
+
loaded_all = torch.load(original_filename, map_location="cpu")['model']
|
206 |
+
# Vision Part
|
207 |
+
state_dict.update({
|
208 |
+
"model.vision_model.embeddings.cls_token": loaded_all['vision_model']['cls_token'],
|
209 |
+
"model.vision_model.embeddings.patch_embed.weight": loaded_all['vision_model']['patch_embed']['weight'],
|
210 |
+
"model.vision_model.embeddings.position_embedding": loaded_all['vision_model']['position_embeddings'],
|
211 |
+
"model.vision_model.embeddings.pre_layernorm.bias": loaded_all['vision_model']['pre_layernorm']['bias'],
|
212 |
+
"model.vision_model.embeddings.pre_layernorm.weight": loaded_all['vision_model']['pre_layernorm']['weight'],
|
213 |
+
"model.vision_model.post_layernorm.bias": loaded_all['vision_model']['transformer']['final_layernorm.bias'],
|
214 |
+
"model.vision_model.post_layernorm.weight": loaded_all['vision_model']['transformer']['final_layernorm.weight'],
|
215 |
+
})
|
216 |
+
for v_layer_idx in range(24):
|
217 |
+
state_dict.update({
|
218 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.bias'],
|
219 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.weight'],
|
220 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.bias'],
|
221 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.weight'],
|
222 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.bias'],
|
223 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.weight'],
|
224 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.bias'],
|
225 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.weight'],
|
226 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.bias'],
|
227 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.weight'],
|
228 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.bias'],
|
229 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.weight'],
|
230 |
+
})
|
231 |
+
|
232 |
+
# Vision2Text Part: HReducer
|
233 |
+
state_dict.update({
|
234 |
+
"model.vision2text.ln_q.weight": loaded_all['hreducer3']['ln_q']['weight'],
|
235 |
+
"model.vision2text.ln_q.bias": loaded_all['hreducer3']['ln_q']['bias'],
|
236 |
+
"model.vision2text.visual_fc.bias": loaded_all['hreducer3']['visual_fc']['bias'],
|
237 |
+
"model.vision2text.visual_fc.weight": loaded_all['hreducer3']['visual_fc']['weight'],
|
238 |
+
"model.vision2text.vit_eos": loaded_all['hreducer3']['vit_eos'],
|
239 |
+
})
|
240 |
+
# reducer_before: conv (layer 0) + gelu (layer 1)
|
241 |
+
state_dict.update({
|
242 |
+
f"model.vision2text.reducer_before.0.weight": loaded_all['hreducer3']['reducer_before']["0.weight"],
|
243 |
+
f"model.vision2text.reducer_before.0.bias": loaded_all['hreducer3']['reducer_before']["0.bias"],
|
244 |
+
})
|
245 |
+
# reducer conv
|
246 |
+
state_dict.update({
|
247 |
+
f"model.vision2text.reducer.weight": loaded_all['hreducer3']['reducer']["weight"],
|
248 |
+
f"model.vision2text.reducer.bias": loaded_all['hreducer3']['reducer']["bias"],
|
249 |
+
})
|
250 |
+
|
251 |
+
for k, v in state_dict.items():
|
252 |
+
# ic(k, v)
|
253 |
+
index_dict["weight_map"][k] = filename
|
254 |
+
param_count += v.numel()
|
255 |
+
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
256 |
+
|
257 |
+
# Write configs
|
258 |
+
index_dict["metadata"] = {"total_size": param_count * 2}
|
259 |
+
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
|
260 |
+
|
261 |
+
config = MPLUGDocOwlConfig()
|
262 |
+
config.save_pretrained(tmp_model_path)
|
263 |
+
|
264 |
+
# Make space so we can load the model properly now.
|
265 |
+
del state_dict
|
266 |
+
del loaded
|
267 |
+
del loaded_all
|
268 |
+
gc.collect()
|
269 |
+
|
270 |
+
def write_tokenizer(tokenizer_path, input_tokenizer_path):
|
271 |
+
# Initialize the tokenizer based on the `spm` model
|
272 |
+
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
|
273 |
+
print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
|
274 |
+
tokenizer = tokenizer_class(input_tokenizer_path)
|
275 |
+
tokenizer.save_pretrained(tokenizer_path)
|
276 |
+
|
277 |
+
|
278 |
+
def main():
|
279 |
+
parser = argparse.ArgumentParser()
|
280 |
+
parser.add_argument(
|
281 |
+
"--input_dir",
|
282 |
+
help="Location of LLaMA_Megatron weights",
|
283 |
+
)
|
284 |
+
parser.add_argument(
|
285 |
+
"--model_size",
|
286 |
+
type=int,
|
287 |
+
default=7,
|
288 |
+
choices=[7, 13, 30, 65, 70],
|
289 |
+
)
|
290 |
+
parser.add_argument(
|
291 |
+
"--num_input_shards",
|
292 |
+
type=int,
|
293 |
+
default=1,
|
294 |
+
)
|
295 |
+
parser.add_argument(
|
296 |
+
"--num_output_shards",
|
297 |
+
type=int,
|
298 |
+
default=1,
|
299 |
+
)
|
300 |
+
parser.add_argument('--skip_permute', action='store_true')
|
301 |
+
|
302 |
+
parser.add_argument(
|
303 |
+
"--output_dir",
|
304 |
+
help="Location to write HF model and tokenizer",
|
305 |
+
)
|
306 |
+
|
307 |
+
args = parser.parse_args()
|
308 |
+
write_model(
|
309 |
+
model_path=args.output_dir,
|
310 |
+
input_base_path=args.input_dir,
|
311 |
+
model_size=args.model_size,
|
312 |
+
num_input_shards=args.num_input_shards,
|
313 |
+
num_output_shards=args.num_output_shards,
|
314 |
+
skip_permute=args.skip_permute
|
315 |
+
)
|
316 |
+
|
317 |
+
|
318 |
+
if __name__ == "__main__":
|
319 |
+
main()
|
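This converter is normally driven through `main()` and the argparse flags above; the sketch below (not part of the uploaded files) shows the equivalent direct call, with both paths as hypothetical placeholders.

```python
# Editorial sketch, not part of the repository. Paths are hypothetical placeholders;
# the input directory must hold a Megatron-style checkpoint as expected by write_model.
from mplug_docowl.model.convert_mplug_docowl_weight_to_hf import write_model, write_tokenizer

write_model(
    model_path="/path/to/hf_output",           # where the per-layer .bin shards and config land
    input_base_path="/path/to/megatron_ckpt",  # directory containing release/ or iter_*/ shards
    model_size=7,                              # 7B layout: 32 layers, 32 heads, 4096 hidden
    num_input_shards=1,                        # only the unsharded branch is implemented above
)
write_tokenizer("/path/to/hf_output", "/path/to/tokenizer.model")  # SentencePiece model file
```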
mplug_docowl/model/convert_mplug_docowl_weight_to_hf_v2.py
ADDED
@@ -0,0 +1,320 @@
1 |
+
# Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import argparse
|
15 |
+
import gc
|
16 |
+
import json
|
17 |
+
import math
|
18 |
+
import os
|
19 |
+
import shutil
|
20 |
+
import warnings
|
21 |
+
|
22 |
+
import torch
|
23 |
+
|
24 |
+
from transformers import LlamaTokenizer
|
25 |
+
from .configuration_mplug_docowl import MPLUGDocOwlConfig
|
26 |
+
from icecream import ic
|
27 |
+
|
28 |
+
try:
|
29 |
+
from transformers import LlamaTokenizerFast
|
30 |
+
except ImportError as e:
|
31 |
+
warnings.warn(e)
|
32 |
+
warnings.warn(
|
33 |
+
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
|
34 |
+
)
|
35 |
+
LlamaTokenizerFast = None
|
36 |
+
|
37 |
+
"""
|
38 |
+
Sample usage:
|
39 |
+
|
40 |
+
```
|
41 |
+
python3 /pure-mlo-scratch/sfan/model-parallel-trainer/llama2megatron/convert_llama2hf.py \
|
42 |
+
--input_dir /pure-mlo-scratch/llama/ --model_size 7 --output_dir /pure-mlo-scratch/llama/converted_HF_7B
|
43 |
+
```
|
44 |
+
|
45 |
+
Thereafter, models can be loaded via:
|
46 |
+
|
47 |
+
```py
|
48 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
49 |
+
|
50 |
+
model = LlamaForCausalLM.from_pretrained("/output/path")
|
51 |
+
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
|
52 |
+
```
|
53 |
+
|
54 |
+
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
|
55 |
+
come in several checkpoints, each containing a part of each weight of the model, so we need to load them all in RAM).
|
56 |
+
"""
|
57 |
+
|
58 |
+
llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
|
59 |
+
llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
|
60 |
+
llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
|
61 |
+
70: 28672} # should be (2/3)*4*d, but it isn't exactly that
|
62 |
+
llama_s2hidden = {7: 4096, 13: 5120, 30: 6656, 65: 8192, 70: 8192}
|
63 |
+
|
64 |
+
|
65 |
+
def compute_intermediate_size(n):
|
66 |
+
return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
|
67 |
+
|
68 |
+
|
69 |
+
def read_json(path):
|
70 |
+
with open(path, "r") as f:
|
71 |
+
return json.load(f)
|
72 |
+
|
73 |
+
|
74 |
+
def write_json(text, path):
|
75 |
+
with open(path, "w") as f:
|
76 |
+
json.dump(text, f)
|
77 |
+
|
78 |
+
|
79 |
+
def write_model(model_path,
|
80 |
+
input_base_path,
|
81 |
+
model_size,
|
82 |
+
num_input_shards=1,
|
83 |
+
num_output_shards=2,
|
84 |
+
skip_permute=True,
|
85 |
+
norm_eps=1e-05):
|
86 |
+
# if os.path.exists(model_path):
|
87 |
+
# shutil.rmtree(model_path)
|
88 |
+
os.makedirs(model_path, exist_ok=True)
|
89 |
+
# tmp_model_path = os.path.join(model_path, "tmp")
|
90 |
+
tmp_model_path = model_path
|
91 |
+
os.makedirs(tmp_model_path, exist_ok=True)
|
92 |
+
|
93 |
+
num_shards = num_input_shards
|
94 |
+
n_layers = llama_s2layer[model_size]
|
95 |
+
n_heads = llama_s2heads[model_size]
|
96 |
+
n_heads_per_shard = n_heads // num_shards
|
97 |
+
n_dense = llama_s2dense[model_size]
|
98 |
+
n_hidden = llama_s2hidden[model_size]
|
99 |
+
hidden_per_head = n_hidden // n_heads
|
100 |
+
base = 10000.0
|
101 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))
|
102 |
+
|
103 |
+
# permute for sliced rotary
|
104 |
+
def permute(w, skip_permute=skip_permute):
|
105 |
+
if skip_permute:
|
106 |
+
return w
|
107 |
+
return w.view(n_heads, n_hidden // n_heads // 2, 2, n_hidden).transpose(1, 2).reshape(n_hidden, n_hidden)
|
108 |
+
|
109 |
+
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
|
110 |
+
# Load weights
|
111 |
+
if num_shards==1:
|
112 |
+
# Not sharded
|
113 |
+
# (The sharded implementation would also work, but this is simpler.)
|
114 |
+
# /pure-mlo-scratch/alhernan/megatron-data/checkpoints/llama2-7b-tp4-pp1-optim/release/mp_rank_00/model_optim_rng.pt
|
115 |
+
if os.path.exists(os.path.join(input_base_path, 'release')):
|
116 |
+
filename = os.path.join(input_base_path, 'release', 'mp_rank_00', 'model_optim_rng.pt')
|
117 |
+
elif input_base_path.split('/')[-1].startswith('iter_'):
|
118 |
+
iteration = eval(input_base_path.split('/')[-1].replace('iter_', '').lstrip('0'))
|
119 |
+
load_dir = '/'.join(input_base_path.split('/')[:-1])
|
120 |
+
filename = os.path.join(input_base_path, 'mp_rank_00', 'model_optim_rng.pt')
|
121 |
+
if not os.path.exists(filename):
|
122 |
+
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
|
123 |
+
else:
|
124 |
+
tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
|
125 |
+
with open(tracker_filename, 'r') as f:
|
126 |
+
metastring = f.read().strip()
|
127 |
+
iteration = 'iter_{:07d}'.format(int(metastring))
|
128 |
+
filename = os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt')
|
129 |
+
if not os.path.exists(filename):
|
130 |
+
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
|
131 |
+
original_filename = filename
|
132 |
+
loaded = torch.load(filename, map_location="cpu")['model']['language_model']
|
133 |
+
|
134 |
+
else:
|
135 |
+
# Sharded
|
136 |
+
filenames = []
|
137 |
+
for i in range(num_shards):
|
138 |
+
if os.path.exists(os.path.join(input_base_path, 'release')):
|
139 |
+
filename = os.path.join(input_base_path, 'release', f'mp_rank_{i:02d}', 'model_optim_rng.pt')
|
140 |
+
else:
|
141 |
+
tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
|
142 |
+
with open(tracker_filename, 'r') as f:
|
143 |
+
metastring = f.read().strip()
|
144 |
+
iteration = 'iter_{:07d}'.format(int(metastring))
|
145 |
+
filename = os.path.join(input_base_path, iteration, f'mp_rank_{i:02d}', 'model_optim_rng.pt')
|
146 |
+
if not os.path.exists(filename):
|
147 |
+
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
|
148 |
+
filenames.append(filename)
|
149 |
+
loaded = [
|
150 |
+
torch.load(filenames[i], map_location="cpu")['model']['language_model']
|
151 |
+
for i in range(num_shards)
|
152 |
+
]
|
153 |
+
|
154 |
+
print('Llama-Megatron Loaded!')
|
155 |
+
param_count = 0
|
156 |
+
index_dict = {"weight_map": {}}
|
157 |
+
state_dict = {}
|
158 |
+
print(f'Weighted Converting for {n_layers} layers...')
|
159 |
+
for layer_i in range(n_layers):
|
160 |
+
print(layer_i)
|
161 |
+
# filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
|
162 |
+
if num_shards == 1:
|
163 |
+
# Unsharded
|
164 |
+
state_dict.update({
|
165 |
+
f"model.layers.{layer_i}.self_attn.q_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"],
|
166 |
+
f"model.layers.{layer_i}.self_attn.k_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"],
|
167 |
+
f"model.layers.{layer_i}.self_attn.v_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"],
|
168 |
+
f"model.layers.{layer_i}.self_attn.k_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.1.weight"],
|
169 |
+
f"model.layers.{layer_i}.self_attn.v_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.1.weight"],
|
170 |
+
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"],
|
171 |
+
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"],
|
172 |
+
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"],
|
173 |
+
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"],
|
174 |
+
f"model.layers.{layer_i}.input_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.0.weight"],
|
175 |
+
f"model.layers.{layer_i}.post_attention_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"],
|
176 |
+
f"model.layers.{layer_i}.input_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.1.weight"],
|
177 |
+
f"model.layers.{layer_i}.post_attention_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.1.weight"],
|
178 |
+
})
|
179 |
+
else:
|
180 |
+
raise NotImplementedError
|
181 |
+
|
182 |
+
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
|
183 |
+
for k, v in state_dict.items():
|
184 |
+
index_dict["weight_map"][k] = filename
|
185 |
+
param_count += v.numel()
|
186 |
+
# torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
187 |
+
# print(f'Sharded file saved to {filename}')
|
188 |
+
|
189 |
+
# filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
|
190 |
+
filename = "pytorch_model.bin"
|
191 |
+
if num_shards==1:
|
192 |
+
# Unsharded
|
193 |
+
state_dict.update({
|
194 |
+
"model.embed_tokens.weight": loaded['embedding']['word_embeddings']['weight'],
|
195 |
+
"model.norm.weight": loaded['encoder']['norm.weight'],
|
196 |
+
"lm_head.weight": loaded['encoder']['lm_head.weight'],
|
197 |
+
})
|
198 |
+
else:
|
199 |
+
state_dict.update({
|
200 |
+
"model.embed_tokens.weight": loaded[0]['embedding']['word_embeddings']['weight'],
|
201 |
+
"model.norm.weight": loaded[0]['encoder']['norm.weight'],
|
202 |
+
"lm_head.weight": loaded[0]['encoder']['lm_head.weight'],
|
203 |
+
})
|
204 |
+
|
205 |
+
loaded_all = torch.load(original_filename, map_location="cpu")['model']
|
206 |
+
# Vision Part
|
207 |
+
state_dict.update({
|
208 |
+
"model.vision_model.embeddings.cls_token": loaded_all['vision_model']['cls_token'],
|
209 |
+
"model.vision_model.embeddings.patch_embed.weight": loaded_all['vision_model']['patch_embed']['weight'],
|
210 |
+
"model.vision_model.embeddings.position_embedding": loaded_all['vision_model']['position_embeddings'],
|
211 |
+
"model.vision_model.embeddings.pre_layernorm.bias": loaded_all['vision_model']['pre_layernorm']['bias'],
|
212 |
+
"model.vision_model.embeddings.pre_layernorm.weight": loaded_all['vision_model']['pre_layernorm']['weight'],
|
213 |
+
"model.vision_model.post_layernorm.bias": loaded_all['vision_model']['transformer']['final_layernorm.bias'],
|
214 |
+
"model.vision_model.post_layernorm.weight": loaded_all['vision_model']['transformer']['final_layernorm.weight'],
|
215 |
+
})
|
216 |
+
for v_layer_idx in range(24):
|
217 |
+
state_dict.update({
|
218 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.bias'],
|
219 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.weight'],
|
220 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.bias'],
|
221 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.weight'],
|
222 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.bias'],
|
223 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.weight'],
|
224 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.bias'],
|
225 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.weight'],
|
226 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.bias'],
|
227 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.weight'],
|
228 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.bias'],
|
229 |
+
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.weight'],
|
230 |
+
})
|
231 |
+
|
232 |
+
# Vision2Text Part: HReducer
|
233 |
+
state_dict.update({
|
234 |
+
"model.vision2text.ln_q.weight": loaded_all['hreducer3']['ln_q']['weight'],
|
235 |
+
"model.vision2text.ln_q.bias": loaded_all['hreducer3']['ln_q']['bias'],
|
236 |
+
"model.vision2text.visual_fc.bias": loaded_all['hreducer3']['visual_fc']['bias'],
|
237 |
+
"model.vision2text.visual_fc.weight": loaded_all['hreducer3']['visual_fc']['weight'],
|
238 |
+
"model.vision2text.vit_eos": loaded_all['hreducer3']['vit_eos'],
|
239 |
+
})
|
240 |
+
# reducer_before: conv (layer 0) + gelu (layer 1)
|
241 |
+
state_dict.update({
|
242 |
+
f"model.vision2text.reducer_before.0.weight": loaded_all['hreducer3']['reducer_before']["0.weight"],
|
243 |
+
f"model.vision2text.reducer_before.0.bias": loaded_all['hreducer3']['reducer_before']["0.bias"],
|
244 |
+
})
|
245 |
+
# reducer conv
|
246 |
+
state_dict.update({
|
247 |
+
f"model.vision2text.reducer.weight": loaded_all['hreducer3']['reducer']["weight"],
|
248 |
+
f"model.vision2text.reducer.bias": loaded_all['hreducer3']['reducer']["bias"],
|
249 |
+
})
|
250 |
+
|
251 |
+
for k, v in state_dict.items():
|
252 |
+
# ic(k, v)
|
253 |
+
index_dict["weight_map"][k] = filename
|
254 |
+
param_count += v.numel()
|
255 |
+
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
256 |
+
print(f'save to {os.path.join(tmp_model_path, filename)}')
|
257 |
+
|
258 |
+
# Write configs
|
259 |
+
index_dict["metadata"] = {"total_size": param_count * 2}
|
260 |
+
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
|
261 |
+
|
262 |
+
config = MPLUGDocOwlConfig()
|
263 |
+
config.save_pretrained(tmp_model_path)
|
264 |
+
|
265 |
+
# Make space so we can load the model properly now.
|
266 |
+
del state_dict
|
267 |
+
del loaded
|
268 |
+
del loaded_all
|
269 |
+
gc.collect()
|
270 |
+
|
271 |
+
def write_tokenizer(tokenizer_path, input_tokenizer_path):
|
272 |
+
# Initialize the tokenizer based on the `spm` model
|
273 |
+
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
|
274 |
+
print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
|
275 |
+
tokenizer = tokenizer_class(input_tokenizer_path)
|
276 |
+
tokenizer.save_pretrained(tokenizer_path)
|
277 |
+
|
278 |
+
|
279 |
+
def main():
|
280 |
+
parser = argparse.ArgumentParser()
|
281 |
+
parser.add_argument(
|
282 |
+
"--input_dir",
|
283 |
+
help="Location of LLaMA_Megatron weights",
|
284 |
+
)
|
285 |
+
parser.add_argument(
|
286 |
+
"--model_size",
|
287 |
+
type=int,
|
288 |
+
default=7,
|
289 |
+
choices=[7, 13, 30, 65, 70],
|
290 |
+
)
|
291 |
+
parser.add_argument(
|
292 |
+
"--num_input_shards",
|
293 |
+
type=int,
|
294 |
+
default=1,
|
295 |
+
)
|
296 |
+
parser.add_argument(
|
297 |
+
"--num_output_shards",
|
298 |
+
type=int,
|
299 |
+
default=1,
|
300 |
+
)
|
301 |
+
parser.add_argument('--skip_permute', action='store_true')
|
302 |
+
|
303 |
+
parser.add_argument(
|
304 |
+
"--output_dir",
|
305 |
+
help="Location to write HF model and tokenizer",
|
306 |
+
)
|
307 |
+
|
308 |
+
args = parser.parse_args()
|
309 |
+
write_model(
|
310 |
+
model_path=args.output_dir,
|
311 |
+
input_base_path=args.input_dir,
|
312 |
+
model_size=args.model_size,
|
313 |
+
num_input_shards=args.num_input_shards,
|
314 |
+
num_output_shards=args.num_output_shards,
|
315 |
+
skip_permute=args.skip_permute
|
316 |
+
)
|
317 |
+
|
318 |
+
|
319 |
+
if __name__ == "__main__":
|
320 |
+
main()
|
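Unlike the first converter, this v2 variant accumulates everything into a single `pytorch_model.bin`; the sketch below (not part of the uploaded files) inspects that output, with the directory path as a hypothetical placeholder.

```python
# Editorial sketch, not part of the repository. The directory is a hypothetical
# placeholder for the output_dir previously passed to write_model.
import json
import torch

output_dir = "/path/to/hf_output"

state_dict = torch.load(f"{output_dir}/pytorch_model.bin", map_location="cpu")
with open(f"{output_dir}/pytorch_model.bin.index.json") as f:
    index = json.load(f)

print(len(state_dict), "tensors")       # LLaMA decoder + ViT encoder + H-Reducer weights
print(index["metadata"]["total_size"])  # rough byte count written as param_count * 2
print(state_dict["model.vision2text.reducer.weight"].shape)  # H-Reducer reduction conv kernel
```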
mplug_docowl/model/modeling_attn_mask_utils.py
ADDED
@@ -0,0 +1,247 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import List, Optional, Tuple, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
|
18 |
+
|
19 |
+
class AttentionMaskConverter:
|
20 |
+
"""
|
21 |
+
A utility attention mask class that allows one to:
|
22 |
+
- Create a causal 4d mask
|
23 |
+
- Create a causal 4d mask with sliding window
|
24 |
+
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
|
25 |
+
key_value_length) that can be multiplied with attention scores
|
26 |
+
|
27 |
+
Parameters:
|
28 |
+
is_causal (`bool`):
|
29 |
+
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
|
30 |
+
|
31 |
+
sliding_window (`int`, *optional*):
|
32 |
+
Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
|
36 |
+
self.is_causal = is_causal
|
37 |
+
self.sliding_window = sliding_window
|
38 |
+
|
39 |
+
if self.sliding_window is not None and self.sliding_window <= 0:
|
40 |
+
raise ValueError(
|
41 |
+
f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
|
42 |
+
)
|
43 |
+
|
44 |
+
def to_causal_4d(
|
45 |
+
self,
|
46 |
+
batch_size: int,
|
47 |
+
query_length: int,
|
48 |
+
key_value_length: int,
|
49 |
+
dtype: torch.dtype = torch.float32,
|
50 |
+
device: Union[torch.device, "str"] = "cpu",
|
51 |
+
) -> torch.Tensor:
|
52 |
+
"""
|
53 |
+
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
|
54 |
+
bias to upper right hand triangular matrix (causal mask).
|
55 |
+
"""
|
56 |
+
if not self.is_causal:
|
57 |
+
raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
|
58 |
+
|
59 |
+
# If shape is not cached, create a new causal mask and cache it
|
60 |
+
input_shape = (batch_size, query_length)
|
61 |
+
past_key_values_length = key_value_length - query_length
|
62 |
+
|
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if input_shape[-1] > 1 or self.sliding_window is not None:
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )

        return causal_4d_mask

    def to_4d(
        self,
        attention_mask_2d: torch.Tensor,
        query_length: int,
        key_value_length: Optional[int] = None,
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """
        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
        key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
        causal, a causal mask will be added.
        """
        input_shape = (attention_mask_2d.shape[0], query_length)

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
            if key_value_length is None:
                raise ValueError(
                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
                )

            past_key_values_length = key_value_length - query_length
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=attention_mask_2d.device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )
        elif self.sliding_window is not None:
            raise NotImplementedError("Sliding window is currently only implemented for causal masking")

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
            attention_mask_2d.device
        )
        expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask

        return expanded_4d_mask

    @staticmethod
    def _make_causal_mask(
        input_ids_shape: torch.Size,
        dtype: torch.dtype,
        device: torch.device,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ):
        """
        Make causal mask used for bi-directional self-attention.
        """
        bsz, tgt_len = input_ids_shape
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
        mask_cond = torch.arange(mask.size(-1), device=device)
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

        mask = mask.to(dtype)

        if past_key_values_length > 0:
            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)

        # add lower triangular sliding window mask if necessary
        if sliding_window is not None:
            diagonal = past_key_values_length - sliding_window + 1

            context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
            mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)

        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

    @staticmethod
    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        inputs_embeds (`torch.Tensor`):
            The embedded inputs as a torch Tensor.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = input_shape[-1] + past_key_values_length

    # 4d mask is passed through the layers
    if attention_mask is not None:
        attention_mask = attn_mask_converter.to_4d(
            attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
        )
    else:
        attention_mask = attn_mask_converter.to_causal_4d(
            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )

    return attention_mask


def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`):
            The target length or query length the created mask shall have.
    """
    return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)


def _create_4d_causal_attention_mask(
    input_shape: Union[torch.Size, Tuple, List],
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: Optional[int] = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`

    Args:
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        device (`torch.device`):
            The torch device the created mask shall have.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = past_key_values_length + input_shape[-1]
    attention_mask = attn_mask_converter.to_causal_4d(
        input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
    )

    return attention_mask
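A minimal sketch of how the helpers above fit together (not part of the upload; the import path is assumed from the file location in this diff). With a 2D padding mask and a non-empty KV cache, `_prepare_4d_causal_attention_mask` returns the additive `(bsz, 1, query_length, key_value_length)` bias that the decoder layers consume.

import torch

# assumed import path based on where this file sits in the repo
from mplug_docowl.model.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, query_len, past_len, hidden = 2, 4, 3, 8
inputs_embeds = torch.zeros(batch_size, query_len, hidden)

# 2D padding mask over past + current tokens (1 = attend, 0 = padding)
attention_mask_2d = torch.ones(batch_size, past_len + query_len, dtype=torch.long)
attention_mask_2d[1, :2] = 0  # pretend the second sample is left-padded

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask_2d, (batch_size, query_len), inputs_embeds, past_len
)
print(mask_4d.shape)  # torch.Size([2, 1, 4, 7]): (bsz, 1, query_len, key_value_length)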
mplug_docowl/model/modeling_llama2.py
ADDED
@@ -0,0 +1,486 @@
import math
import warnings
from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

import transformers
from transformers.models.llama.modeling_llama import *
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from .configuration_mplug_docowl import LlamaConfig

class MultiwayNetwork(nn.Module):

    def __init__(self, module_provider, num_multiway=2):
        super(MultiwayNetwork, self).__init__()

        self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])

    def forward(self, hidden_states, multiway_indices):

        if len(self.multiway) == 1:
            return self.multiway[0](hidden_states)

        output_hidden_states = torch.empty_like(hidden_states)

        for idx, subway in enumerate(self.multiway):
            local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
            hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
            if hidden.numel():
                output = subway(hidden)
                if isinstance(output, tuple):
                    output = output[0]
                output = output.squeeze(1)
                output_hidden_states[local_indices] = output

        return output_hidden_states.contiguous()


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = MultiwayNetwork(module_provider=partial(
            nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        )
        self.v_proj = MultiwayNetwork(module_provider=partial(
            nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        )
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_indicators: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        padding_mask: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states, )
        key_states = self.k_proj(hidden_states, modality_indicators)
        value_states = self.v_proj(hidden_states, modality_indicators)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = MultiwayNetwork(module_provider=partial(
            LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
        ))
        self.post_attention_layernorm = MultiwayNetwork(module_provider=partial(
            LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
        ))

    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_indicators: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states, modality_indicators)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            modality_indicators=modality_indicators,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states, modality_indicators)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


def model_forward(
    self,
    input_ids: torch.LongTensor = None,
    modality_indicators: torch.Tensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    seq_length_with_past = seq_length
    past_key_values_length = 0

    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
        )
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, past_key_value, output_attentions)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                modality_indicators,
                attention_mask,
                position_ids,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                modality_indicators=modality_indicators,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


def causal_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    modality_indicators: torch.Tensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, LlamaForCausalLM

    >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        modality_indicators=modality_indicators,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    if self.config.pretraining_tp > 1:
        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
        logits = torch.cat(logits, dim=-1)
    else:
        logits = self.lm_head(hidden_states)
    logits = logits.float()

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

def replace_llama_modality_adaptive():
    transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig
    transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
    transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
    transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward
    transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward


if __name__ == "__main__":
    replace_llama_modality_adaptive()
    config = transformers.LlamaConfig.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/')
    model = transformers.LlamaForCausalLM(config)
    print(model)
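A minimal, self-contained sketch of the modality-adaptive routing that MultiwayNetwork implements (not part of the upload; the import path is assumed from this diff). Each token is sent through one of two parallel sub-modules according to its 0/1 modality indicator, which is how the patched k/v projections and layernorms treat text and visual tokens differently.

import torch
from functools import partial
from torch import nn

# assumed import path based on where this file sits in the repo
from mplug_docowl.model.modeling_llama2 import MultiwayNetwork

# Two parallel linear "ways": indicator 0 routes a token through way 0 (text),
# indicator 1 routes it through way 1 (visual).
net = MultiwayNetwork(module_provider=partial(nn.Linear, in_features=8, out_features=8))

hidden = torch.randn(1, 5, 8)                 # (bsz, seq_len, hidden_size)
indicators = torch.tensor([[0, 0, 1, 1, 0]])  # per-token modality indicator
out = net(hidden, indicators)
print(out.shape)                              # torch.Size([1, 5, 8]), same layout as the input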
mplug_docowl/model/modeling_mplug_docowl.py
ADDED
@@ -0,0 +1,313 @@
#    Copyright 2023 Haotian Liu & Qinghao Ye (Modified from LLaVA)
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_mplug_docowl import (MPLUGDocOwlConfig, MplugOwlVisionConfig, MplugDocOwlHReducerConfig)
from .visual_encoder import MplugOwlVisionModel, MplugDocOwlHReducerModel
from .modeling_llama2 import replace_llama_modality_adaptive
from mplug_docowl.constants import IMAGE_TOKEN_INDEX, IGNORE_INDEX
from icecream import ic

class MPLUGDocOwlMetaModel:
    def __init__(self, config):
        super(MPLUGDocOwlMetaModel, self).__init__(config)
        self.vision_model = MplugOwlVisionModel(
            MplugOwlVisionConfig(**config.visual_config["visual_model"])
        )

        self.vision2text = MplugDocOwlHReducerModel(
            MplugDocOwlHReducerConfig(**config.visual_config["visual_hreducer"]), config.hidden_size
        )

    def get_vision_tower(self):
        vision_model = getattr(self, 'vision_model', None)
        if type(vision_model) is list:
            vision_model = vision_model[0]
        return vision_model

    def get_vision2text(self):
        vision2text = getattr(self, 'vision2text', None)
        if type(vision2text) is list:
            vision2text = vision2text[0]
        return vision2text

class MPLUGDocOwlMetaForCausalLM(ABC):
    @abstractmethod
    def get_model(self):
        pass

    def encode_images(self, images, patch_positions):
        image_features = self.get_model().vision_model(images).last_hidden_state
        image_features = self.get_model().vision2text(encoder_hidden_states=image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images, patch_positions
    ):
        if images is None or input_ids.shape[1] == 1:
            if past_key_values is not None and images is not None and input_ids.shape[1] == 1:
                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
            multiway_indices = torch.zeros_like(input_ids).long().to(self.device)
            return input_ids, multiway_indices, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images, patch_positions)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images, patch_positions) # Sum(Crop Image Number) x L x d

        new_input_embeds = []
        new_modality_indicators = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal
                # FIXME: this is a hacky fix, for deepspeed zero3 to work
                half_len = cur_input_ids.shape[0] // 2
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
                cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
                new_input_embeds.append(cur_input_embeds)

                cur_modality_indicators = torch.zeros(len(cur_input_embeds)).long().to(self.device)
                new_modality_indicators.append(cur_modality_indicators)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            cur_modality_indicators = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
                cur_new_input_embeds.append(cur_image_features)

                # Add modality indicator
                assert image_token_start == len(cur_input_ids[:image_token_start])
                cur_modality_indicators.append(torch.zeros(len(cur_input_ids[:image_token_start])).long())
                cur_modality_indicators.append(torch.ones(len(cur_image_features)).long())

                if labels is not None:
                    cur_new_labels.append(cur_labels[:image_token_start])
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                    cur_labels = cur_labels[image_token_start+1:]
                cur_image_idx += 1
                cur_input_ids = cur_input_ids[image_token_start+1:]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            if cur_input_ids.numel() > 0:
                cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                cur_modality_indicators.append(torch.zeros(len(cur_input_ids)).long())
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)

            # Modality
            cur_modality_indicators = [x.to(device=self.device) for x in cur_modality_indicators]
            cur_modality_indicators = torch.cat(cur_modality_indicators, dim=0)
            new_modality_indicators.append(cur_modality_indicators)


            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            max_len = max(x.shape[0] for x in new_input_embeds)

            # Embedding
            new_input_embeds_align = []
            for cur_new_embed in new_input_embeds:
                cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
                new_input_embeds_align.append(cur_new_embed)
            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

            # Modality
            new_modality_indicators_align = []
            for cur_modality_indicator in new_modality_indicators:
                cur_new_embed = torch.cat((cur_modality_indicator, torch.zeros(max_len - cur_modality_indicator.shape[0], dtype=cur_modality_indicator.dtype, device=cur_modality_indicator.device)), dim=0)
                new_modality_indicators_align.append(cur_new_embed)
            new_modality_indicators = torch.stack(new_modality_indicators_align, dim=0)

            # Label
            if labels is not None:
                new_labels_align = []
                _new_labels = new_labels
                for cur_new_label in new_labels:
                    cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
                    new_labels_align.append(cur_new_label)
                new_labels = torch.stack(new_labels_align, dim=0)

            # Attention Mask
            if attention_mask is not None:
                new_attention_mask = []
                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
                    new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
                    new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_labels.shape
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            new_modality_indicators = torch.stack(new_modality_indicators, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]
        return None, new_modality_indicators, attention_mask, past_key_values, new_input_embeds, new_labels


class MPLUGDocOwlLlamaModel(MPLUGDocOwlMetaModel, LlamaModel):
    config_class = MPLUGDocOwlConfig

    def __init__(self, config: MPLUGDocOwlConfig):
        super(MPLUGDocOwlLlamaModel, self).__init__(config)


class MPLUGDocOwlLlamaForCausalLM(LlamaForCausalLM, MPLUGDocOwlMetaForCausalLM):
    config_class = MPLUGDocOwlConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = MPLUGDocOwlLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        # modality_indicators: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        patch_positions: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        # print('modeling_mplug_docow2.py patch_positions:', patch_positions)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        input_ids, modality_indicators, attention_mask, past_key_values, inputs_embeds, labels = \
            self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images, patch_positions)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            modality_indicators=modality_indicators,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
                "patch_positions": kwargs.get("patch_positions", None),
            }
        )
        return model_inputs

AutoConfig.register("mplug_docowl", MPLUGDocOwlConfig)
AutoModelForCausalLM.register(MPLUGDocOwlConfig, MPLUGDocOwlLlamaForCausalLM)

replace_llama_modality_adaptive()
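A minimal, self-contained sketch (not part of the upload) of the splicing idea behind prepare_inputs_labels_for_multimodal above: each IMAGE_TOKEN_INDEX placeholder in the token sequence is cut out and replaced by a run of visual features, while a parallel modality indicator marks which positions should later be routed through the visual branches. The sentinel value -200 and the toy shapes are assumptions for illustration only.

import torch

IMAGE_TOKEN_INDEX = -200          # assumed sentinel, mirroring mplug_docowl.constants
hidden_size, num_visual_tokens = 8, 3

input_ids = torch.tensor([5, 6, IMAGE_TOKEN_INDEX, 7])        # "<text> <text> <image> <text>"
embed = torch.nn.Embedding(100, hidden_size)                  # stand-in for the LLM token embedding
image_features = torch.randn(num_visual_tokens, hidden_size)  # stand-in for the H-Reducer output

start = int((input_ids == IMAGE_TOKEN_INDEX).nonzero()[0, 0])
inputs_embeds = torch.cat(
    [embed(input_ids[:start]), image_features, embed(input_ids[start + 1:])], dim=0
)
modality_indicators = torch.cat(
    [torch.zeros(start), torch.ones(num_visual_tokens), torch.zeros(len(input_ids) - start - 1)]
).long()
print(inputs_embeds.shape, modality_indicators.tolist())  # torch.Size([6, 8]) [0, 0, 1, 1, 1, 0]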
mplug_docowl/model/utils.py
ADDED
@@ -0,0 +1,20 @@
from transformers import AutoConfig


def auto_upgrade(config):
    cfg = AutoConfig.from_pretrained(config)
    if 'mplug_owl2' in config and 'mplug_owl2' not in cfg.model_type:
        assert cfg.model_type == 'mplug_owl2'
        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            setattr(cfg.__class__, "model_type", "mplug_owl2")
            cfg.architectures[0] = 'LlavaLlamaForCausalLM'
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            exit(1)
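Illustrative call only (the path is a placeholder, not part of the upload): auto_upgrade reads the config from the given location, prompts on stdin, and writes the modified config back with cfg.save_pretrained(config), so it is meant to be pointed at a local checkpoint directory.

from mplug_docowl.model.utils import auto_upgrade  # assumed import path

auto_upgrade("/path/to/old_mplug_owl2_checkpoint")  # placeholder directory containing config.json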
mplug_docowl/model/visual_encoder.py
ADDED
@@ -0,0 +1,499 @@
1 |
+
import math
|
2 |
+
from typing import Any, Optional, Tuple, Union
|
3 |
+
|
4 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPastAndCrossAttentions
|
5 |
+
from transformers.modeling_utils import PreTrainedModel
|
6 |
+
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.utils.checkpoint
|
12 |
+
from icecream import ic
|
13 |
+
import einops
|
14 |
+
from einops import rearrange
|
15 |
+
|
16 |
+
def get_abs_pos(abs_pos, tgt_size):
|
17 |
+
# abs_pos: L, C
|
18 |
+
# tgt_size: M
|
19 |
+
# return: M, C
|
20 |
+
src_size = int(math.sqrt(abs_pos.size(0)))
|
21 |
+
tgt_size = int(math.sqrt(tgt_size))
|
22 |
+
dtype = abs_pos.dtype
|
23 |
+
|
24 |
+
if src_size != tgt_size:
|
25 |
+
return F.interpolate(
|
26 |
+
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
|
27 |
+
size=(tgt_size, tgt_size),
|
28 |
+
mode="bicubic",
|
29 |
+
align_corners=False,
|
30 |
+
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
|
31 |
+
else:
|
32 |
+
return abs_pos
|
33 |
+
|
34 |
+
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
35 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
36 |
+
"""
|
37 |
+
grid_size: int of the grid height and width
|
38 |
+
return:
|
39 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
40 |
+
"""
|
41 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
42 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
43 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
44 |
+
grid = np.stack(grid, axis=0)
|
45 |
+
|
46 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
47 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
48 |
+
if cls_token:
|
49 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
50 |
+
return pos_embed
|
51 |
+
|
52 |
+
|
53 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
54 |
+
assert embed_dim % 2 == 0
|
55 |
+
|
56 |
+
# use half of dimensions to encode grid_h
|
57 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
58 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
59 |
+
|
60 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
61 |
+
return emb
|
62 |
+
|
63 |
+
|
64 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
65 |
+
"""
|
66 |
+
embed_dim: output dimension for each position
|
67 |
+
pos: a list of positions to be encoded: size (M,)
|
68 |
+
out: (M, D)
|
69 |
+
"""
|
70 |
+
assert embed_dim % 2 == 0
|
71 |
+
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
72 |
+
omega /= embed_dim / 2.
|
73 |
+
omega = 1. / 10000**omega # (D/2,)
|
74 |
+
|
75 |
+
pos = pos.reshape(-1) # (M,)
|
76 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
77 |
+
|
78 |
+
emb_sin = np.sin(out) # (M, D/2)
|
79 |
+
emb_cos = np.cos(out) # (M, D/2)
|
80 |
+
|
81 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
82 |
+
return emb
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
class MplugOwlVisionEmbeddings(nn.Module):
|
87 |
+
def __init__(self, config):
|
88 |
+
super().__init__()
|
89 |
+
self.config = config
|
90 |
+
self.hidden_size = config.hidden_size
|
91 |
+
self.image_size = config.image_size
|
92 |
+
self.patch_size = config.patch_size
|
93 |
+
|
94 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
|
95 |
+
|
96 |
+
self.patch_embed = nn.Conv2d(
|
97 |
+
in_channels=3,
|
98 |
+
out_channels=self.hidden_size,
|
99 |
+
kernel_size=self.patch_size,
|
100 |
+
stride=self.patch_size,
|
101 |
+
bias=False,
|
102 |
+
)
|
103 |
+
|
104 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
105 |
+
|
106 |
+
self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.hidden_size))
|
107 |
+
|
108 |
+
self.pre_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
109 |
+
|
110 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
111 |
+
batch_size = pixel_values.size(0)
|
112 |
+
image_embeds = self.patch_embed(pixel_values)
|
113 |
+
image_embeds = image_embeds.flatten(2).transpose(1, 2)
|
114 |
+
|
115 |
+
class_embeds = self.cls_token.expand(batch_size, 1, -1).to(image_embeds.dtype)
|
116 |
+
embeddings = torch.cat([class_embeds, image_embeds], dim=1)
|
117 |
+
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1)].to(image_embeds.dtype)
|
118 |
+
embeddings = self.pre_layernorm(embeddings)
|
119 |
+
return embeddings
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
class MplugOwlVisionAttention(nn.Module):
|
124 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
125 |
+
|
126 |
+
def __init__(self, config):
|
127 |
+
super().__init__()
|
128 |
+
self.config = config
|
129 |
+
self.hidden_size = config.hidden_size
|
130 |
+
self.num_heads = config.num_attention_heads
|
131 |
+
self.head_dim = self.hidden_size // self.num_heads
|
132 |
+
if self.head_dim * self.num_heads != self.hidden_size:
|
133 |
+
raise ValueError(
|
134 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
|
135 |
+
f" {self.num_heads})."
|
136 |
+
)
|
137 |
+
self.scale = self.head_dim**-0.5
|
138 |
+
self.dropout = nn.Dropout(config.attention_dropout)
|
139 |
+
|
140 |
+
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size)
|
141 |
+
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
|
142 |
+
|
143 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
144 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
145 |
+
|
146 |
+
def forward(
|
147 |
+
self,
|
148 |
+
hidden_states: torch.Tensor,
|
149 |
+
head_mask: Optional[torch.Tensor] = None,
|
150 |
+
output_attentions: Optional[bool] = False,
|
151 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
152 |
+
"""Input shape: Batch x Time x Channel"""
|
153 |
+
|
154 |
+
bsz, seq_len, embed_dim = hidden_states.size()
|
155 |
+
|
156 |
+
mixed_qkv = self.query_key_value(hidden_states)
|
157 |
+
|
158 |
+
mixed_qkv = mixed_qkv.reshape(bsz, seq_len, self.num_heads, 3, embed_dim // self.num_heads).permute(
|
159 |
+
3, 0, 2, 1, 4
|
160 |
+
) # [3, b, np, sq, hn]
|
161 |
+
query_states, key_states, value_states = (
|
162 |
+
mixed_qkv[0],
|
163 |
+
mixed_qkv[1],
|
164 |
+
mixed_qkv[2],
|
165 |
+
)
|
166 |
+
# if self.config.use_flash_attn and flash_attn_func is not None:
|
167 |
+
if False:
|
168 |
+
# [b*sq, np, hn]
|
169 |
+
query_states = query_states.permute(0, 2, 1, 3).contiguous()
|
170 |
+
query_states = query_states.view(query_states.size(0) * query_states.size(1), query_states.size(2), -1)
|
171 |
+
|
172 |
+
key_states = key_states.permute(0, 2, 1, 3).contiguous()
|
173 |
+
key_states = key_states.view(key_states.size(0) * key_states.size(1), key_states.size(2), -1)
|
174 |
+
|
175 |
+
value_states = value_states.permute(0, 2, 1, 3).contiguous()
|
176 |
+
value_states = value_states.view(value_states.size(0) * value_states.size(1), value_states.size(2), -1)
|
177 |
+
|
178 |
+
cu_seqlens = torch.arange(
|
179 |
+
0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=query_states.device
|
180 |
+
)
|
181 |
+
|
182 |
+
context_layer = flash_attn_func(
|
183 |
+
query_states,
|
184 |
+
key_states,
|
185 |
+
value_states,
|
186 |
+
cu_seqlens,
|
187 |
+
cu_seqlens,
|
188 |
+
seq_len,
|
189 |
+
seq_len,
|
190 |
+
self.dropout if self.training else 0.0,
|
191 |
+
softmax_scale=self.scale,
|
192 |
+
causal=False,
|
193 |
+
return_attn_probs=False,
|
194 |
+
)
|
195 |
+
# [b*sq, np, hn] => [b, sq, np, hn]
|
196 |
+
context_layer = context_layer.view(bsz, seq_len, context_layer.size(1), context_layer.size(2))
|
197 |
+
else:
|
198 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
199 |
+
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
|
200 |
+
|
201 |
+
attention_scores = attention_scores * self.scale
|
202 |
+
|
203 |
+
# Normalize the attention scores to probabilities.
|
204 |
+
attention_probs = torch.softmax(attention_scores, dim=-1)
|
205 |
+
|
206 |
+
# This is actually dropping out entire tokens to attend to, which might
|
207 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
208 |
+
attention_probs = self.dropout(attention_probs)
|
209 |
+
|
210 |
+
# Mask heads if we want to
|
211 |
+
if head_mask is not None:
|
212 |
+
attention_probs = attention_probs * head_mask
|
213 |
+
|
214 |
+
context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
|
215 |
+
|
216 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
|
217 |
+
context_layer = context_layer.reshape(new_context_layer_shape)
|
218 |
+
|
219 |
+
output = self.dense(context_layer)
|
220 |
+
|
221 |
+
outputs = (output, attention_probs) if output_attentions else (output, None)
|
222 |
+
|
223 |
+
return outputs
|
224 |
+
|
225 |
+
|
226 |
+
class QuickGELU(nn.Module):
|
227 |
+
def forward(self, x: torch.Tensor):
|
228 |
+
return x * torch.sigmoid(1.702 * x)
|
229 |
+
|
230 |
+
|
231 |
+
class MplugOwlMLP(nn.Module):
|
232 |
+
def __init__(self, config):
|
233 |
+
super().__init__()
|
234 |
+
self.config = config
|
235 |
+
self.activation_fn = QuickGELU()
|
236 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
237 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
238 |
+
|
239 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
240 |
+
hidden_states = self.fc1(hidden_states)
|
241 |
+
hidden_states = self.activation_fn(hidden_states)
|
242 |
+
hidden_states = self.fc2(hidden_states)
|
243 |
+
return hidden_states
|
244 |
+
|
245 |
+
|
246 |
+
class MplugOwlVisionEncoderLayer(nn.Module):
|
247 |
+
def __init__(self, config):
|
248 |
+
super().__init__()
|
249 |
+
self.hidden_size = config.hidden_size
|
250 |
+
self.self_attn = MplugOwlVisionAttention(config)
|
251 |
+
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
252 |
+
self.mlp = MplugOwlMLP(config)
|
253 |
+
self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
254 |
+
|
255 |
+
def forward(
|
256 |
+
self,
|
257 |
+
hidden_states: torch.Tensor,
|
258 |
+
attention_mask: torch.Tensor,
|
259 |
+
output_attentions: Optional[bool] = False,
|
260 |
+
) -> Tuple[torch.FloatTensor]:
|
261 |
+
"""
|
262 |
+
Args:
|
263 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
264 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
265 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
267 |
+
output_attentions (`bool`, *optional*):
|
268 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
269 |
+
returned tensors for more detail.
|
270 |
+
"""
|
271 |
+
residual = hidden_states
|
272 |
+
|
273 |
+
hidden_states = self.input_layernorm(hidden_states)
|
274 |
+
hidden_states, attn_weights = self.self_attn(
|
275 |
+
hidden_states=hidden_states,
|
276 |
+
head_mask=attention_mask,
|
277 |
+
output_attentions=output_attentions,
|
278 |
+
)
|
279 |
+
hidden_states = hidden_states + residual
|
280 |
+
residual = hidden_states
|
281 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
282 |
+
hidden_states = self.mlp(hidden_states)
|
283 |
+
|
284 |
+
hidden_states = hidden_states + residual
|
285 |
+
|
286 |
+
outputs = (hidden_states,)
|
287 |
+
|
288 |
+
if output_attentions:
|
289 |
+
outputs += (attn_weights,)
|
290 |
+
|
291 |
+
return outputs
|
292 |
+
|
293 |
+
|
294 |
+
class MplugOwlVisionEncoder(nn.Module):
|
295 |
+
"""
|
296 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
297 |
+
[`MplugOwlVisionEncoderLayer`].
|
298 |
+
|
299 |
+
Args:
|
300 |
+
config (`MplugOwlVisionConfig`):
|
301 |
+
The corresponding vision configuration for the `MplugOwlEncoder`.
|
302 |
+
"""
|
303 |
+
|
304 |
+
def __init__(self, config):
|
305 |
+
super().__init__()
|
306 |
+
self.config = config
|
307 |
+
self.layers = nn.ModuleList([MplugOwlVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
308 |
+
self.gradient_checkpointing = True
|
309 |
+
|
310 |
+
def forward(
|
311 |
+
self,
|
312 |
+
inputs_embeds,
|
313 |
+
attention_mask: Optional[torch.Tensor] = None,
|
314 |
+
output_attentions: Optional[bool] = None,
|
315 |
+
output_hidden_states: Optional[bool] = None,
|
316 |
+
return_dict: Optional[bool] = None,
|
317 |
+
) -> Union[Tuple, BaseModelOutput]:
|
318 |
+
r"""
|
319 |
+
Args:
|
320 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
321 |
+
Embedded representation of the inputs. Should be float, not int tokens.
|
322 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
323 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
324 |
+
|
325 |
+
- 1 for tokens that are **not masked**,
|
326 |
+
- 0 for tokens that are **masked**.
|
327 |
+
|
328 |
+
[What are attention masks?](../glossary#attention-mask)
|
329 |
+
output_attentions (`bool`, *optional*):
|
330 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
331 |
+
returned tensors for more detail.
|
332 |
+
output_hidden_states (`bool`, *optional*):
|
333 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
334 |
+
for more detail.
|
335 |
+
return_dict (`bool`, *optional*):
|
336 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
337 |
+
"""
|
338 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
339 |
+
output_hidden_states = (
|
340 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
341 |
+
)
|
342 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
343 |
+
|
344 |
+
encoder_states = () if output_hidden_states else None
|
345 |
+
all_attentions = () if output_attentions else None
|
346 |
+
|
347 |
+
hidden_states = inputs_embeds
|
348 |
+
for idx, encoder_layer in enumerate(self.layers):
|
349 |
+
if output_hidden_states:
|
350 |
+
encoder_states = encoder_states + (hidden_states,)
|
351 |
+
if self.gradient_checkpointing and self.training:
|
352 |
+
|
353 |
+
def create_custom_forward(module):
|
354 |
+
def custom_forward(*inputs):
|
355 |
+
return module(*inputs, output_attentions)
|
356 |
+
|
357 |
+
return custom_forward
|
358 |
+
|
359 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
360 |
+
create_custom_forward(encoder_layer),
|
361 |
+
hidden_states,
|
362 |
+
attention_mask,
|
363 |
+
)
|
364 |
+
else:
|
365 |
+
layer_outputs = encoder_layer(
|
366 |
+
hidden_states,
|
367 |
+
attention_mask,
|
368 |
+
output_attentions=output_attentions,
|
369 |
+
)
|
370 |
+
|
371 |
+
hidden_states = layer_outputs[0]
|
372 |
+
|
373 |
+
if output_attentions:
|
374 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
375 |
+
|
376 |
+
if output_hidden_states:
|
377 |
+
encoder_states = encoder_states + (hidden_states,)
|
378 |
+
|
379 |
+
if not return_dict:
|
380 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
381 |
+
return BaseModelOutput(
|
382 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
383 |
+
)
|
384 |
+
|
385 |
+
|
386 |
+
class MplugOwlVisionModel(PreTrainedModel):
|
387 |
+
main_input_name = "pixel_values"
|
388 |
+
|
389 |
+
def __init__(self, config):
|
390 |
+
super().__init__(config)
|
391 |
+
self.config = config
|
392 |
+
self.hidden_size = config.hidden_size
|
393 |
+
|
394 |
+
self.embeddings = MplugOwlVisionEmbeddings(config)
|
395 |
+
self.encoder = MplugOwlVisionEncoder(config)
|
396 |
+
self.post_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
397 |
+
|
398 |
+
self.post_init()
|
399 |
+
|
400 |
+
|
401 |
+
def forward(
|
402 |
+
self,
|
403 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
404 |
+
output_attentions: Optional[bool] = None,
|
405 |
+
output_hidden_states: Optional[bool] = None,
|
406 |
+
return_dict: Optional[bool] = None,
|
407 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
408 |
+
r"""
|
409 |
+
Returns:
|
410 |
+
|
411 |
+
"""
|
412 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
413 |
+
output_hidden_states = (
|
414 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
415 |
+
)
|
416 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
417 |
+
|
418 |
+
if pixel_values is None:
|
419 |
+
raise ValueError("You have to specify pixel_values")
|
420 |
+
|
421 |
+
hidden_states = self.embeddings(pixel_values)
|
422 |
+
|
423 |
+
encoder_outputs = self.encoder(
|
424 |
+
inputs_embeds=hidden_states,
|
425 |
+
output_attentions=output_attentions,
|
426 |
+
output_hidden_states=output_hidden_states,
|
427 |
+
return_dict=return_dict,
|
428 |
+
)
|
429 |
+
|
430 |
+
last_hidden_state = encoder_outputs[0]
|
431 |
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
432 |
+
|
433 |
+
pooled_output = last_hidden_state[:, 0, :]
|
434 |
+
pooled_output = self.post_layernorm(pooled_output)
|
435 |
+
|
436 |
+
if not return_dict:
|
437 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
438 |
+
|
439 |
+
return BaseModelOutputWithPooling(
|
440 |
+
last_hidden_state=last_hidden_state,
|
441 |
+
pooler_output=pooled_output,
|
442 |
+
hidden_states=encoder_outputs.hidden_states,
|
443 |
+
attentions=encoder_outputs.attentions,
|
444 |
+
)
|
445 |
+
|
446 |
+
def get_input_embeddings(self):
|
447 |
+
return self.embeddings
|
448 |
+
|
449 |
+
|
450 |
+
class MplugDocOwlHReducerModel(PreTrainedModel):
|
451 |
+
def __init__(self, config, language_hidden_size):
|
452 |
+
super().__init__(config)
|
453 |
+
self.config = config
|
454 |
+
self.ln_q = torch.nn.LayerNorm(self.config.hidden_size, eps=1e-6)
|
455 |
+
self.conv_shape = (int(self.config.conv_shape.split('x')[0]), int(self.config.conv_shape.split('x')[1])) #
|
456 |
+
self.conv_patch=self.conv_shape[0]*self.conv_shape[1]
|
457 |
+
## feature interaction with a conv layer
|
458 |
+
self.reducer_before = torch.nn.Sequential(
|
459 |
+
nn.Conv2d(self.config.hidden_size, self.conv_patch*self.config.hidden_size, kernel_size=self.conv_shape, stride=self.conv_shape, bias=True),
|
460 |
+
nn.GELU()
|
461 |
+
)
|
462 |
+
## reduce visual feature length with a conv layer
|
463 |
+
self.reducer = nn.Conv2d(self.config.hidden_size, self.config.hidden_size, kernel_size=self.conv_shape, stride=self.conv_shape, bias=True)
|
464 |
+
## align visual features with language embedding with fc
|
465 |
+
self.visual_fc = torch.nn.Linear(self.config.hidden_size, language_hidden_size)
|
466 |
+
self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
|
467 |
+
|
468 |
+
self.post_init()
|
469 |
+
|
470 |
+
def forward(
|
471 |
+
self,
|
472 |
+
encoder_hidden_states=None
|
473 |
+
):
|
474 |
+
r"""
|
475 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
|
476 |
+
batch_size is the number of all images (global+crop) in a batch
|
477 |
+
Sequence of hidden-states at the output of the last layer of the encoder.
|
478 |
+
"""
|
479 |
+
encoder_hidden_states = encoder_hidden_states[:,1:,:] # remove the first cls token
|
480 |
+
B, L, C = encoder_hidden_states.shape # B, 1024=(448/14)^2, 1024
|
481 |
+
|
482 |
+
## feature interaction with a conv layer
|
483 |
+
encoder_hidden_states = rearrange(encoder_hidden_states, 'B (H W) D -> B D H W', H=int(math.sqrt(L)))
|
484 |
+
hidden_states = self.reducer_before(encoder_hidden_states) # B 4D H W/4
|
485 |
+
## reduce seq length with a conv layer
|
486 |
+
"""hidden_states = hidden_states.flatten(2).transpose(1, 2) # B 4D H W/4 -> B 4D H*W/4 -> B H*W/4 4D
|
487 |
+
hidden_states = rearrange(hidden_states, 'B L (X D) -> B (L X) D', X=self.conv_patch) # B (H W) D
|
488 |
+
hidden_states = rearrange(hidden_states, 'B (H W) D -> B D H W', H=int(math.sqrt(L))) # B D H W """
|
489 |
+
hidden_states = rearrange(hidden_states, 'B (X D) H W -> B D H (W X)', X=self.conv_patch) # B 4D H W/4 -> B D H W
|
490 |
+
sequence_output = self.reducer(hidden_states) # B,C,H,W -> B,C,H/conv_shape[0],W/conv_shape[1]
|
491 |
+
sequence_output = sequence_output.flatten(2).transpose(1, 2) # B,C,H/conv_shape[1],W/(conv_shape[1]) -> B,C,L/conv_patch -> B,L/conv_patch,C
|
492 |
+
sequence_output = sequence_output.transpose(0, 1).contiguous() # L/conv_patch, B, C
|
493 |
+
## align visual features with language embedding with fc
|
494 |
+
sequence_output = self.visual_fc(sequence_output) # L/conv_patch, B, h
|
495 |
+
sequence_output = sequence_output.transpose(0, 1).contiguous() # B, s/4, h
|
496 |
+
sequence_output = torch.cat([sequence_output, self.vit_eos.repeat(B, 1, 1)], dim=1)
|
497 |
+
|
498 |
+
return sequence_output
|
499 |
+
|
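The H-Reducer defined above compresses the ViT token sequence by merging horizontally adjacent features with two convolutions, then projects the result into the language embedding space. Below is a minimal, standalone sketch of that shape flow, not the class itself: it assumes hidden_size=1024, conv_shape='1x4', and a 448x448 input (32x32 patches after the 14px patch embedding, as in the comments above), and it omits the GELU, visual_fc, and vit_eos steps.

import torch
import torch.nn as nn
from einops import rearrange

B, D, H, W, conv = 2, 1024, 32, 32, (1, 4)           # assumed sizes, see above
conv_patch = conv[0] * conv[1]                        # 4 tokens merged per step
feats = torch.randn(B, H * W, D)                      # ViT tokens, cls token already removed
x = rearrange(feats, 'b (h w) d -> b d h w', h=H)                    # (B, 1024, 32, 32)
x = nn.Conv2d(D, conv_patch * D, kernel_size=conv, stride=conv)(x)   # (B, 4096, 32, 8)
x = rearrange(x, 'b (x d) h w -> b d h (w x)', x=conv_patch)         # (B, 1024, 32, 32)
x = nn.Conv2d(D, D, kernel_size=conv, stride=conv)(x)                # (B, 1024, 32, 8)
x = x.flatten(2).transpose(1, 2)                                     # (B, 256, 1024): 4x fewer tokens
print(x.shape)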
mplug_docowl/processor.py
ADDED
@@ -0,0 +1,219 @@
1 |
+
from einops import rearrange, repeat
|
2 |
+
import torch
|
3 |
+
from torchvision import transforms
|
4 |
+
from PIL import Image, ImageFile
|
5 |
+
import random
|
6 |
+
from torchvision.ops.boxes import box_area
|
7 |
+
|
8 |
+
from torchvision.transforms.transforms import InterpolationMode
|
9 |
+
from torchvision.transforms import functional as F
|
10 |
+
import numpy as np
|
11 |
+
from icecream import ic
|
12 |
+
|
13 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
14 |
+
ImageFile.MAX_IMAGE_PIXELS = None
|
15 |
+
Image.MAX_IMAGE_PIXELS = None
|
16 |
+
|
17 |
+
def box_iou(boxes1, area1, boxes2, eps=1e-5):
|
18 |
+
area2 = box_area(boxes2)
|
19 |
+
|
20 |
+
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
21 |
+
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
22 |
+
|
23 |
+
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
24 |
+
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
25 |
+
|
26 |
+
union = area1[:, None] + area2 - inter
|
27 |
+
|
28 |
+
iou = inter / (union+eps)
|
29 |
+
return iou, union
|
30 |
+
|
31 |
+
def anchor_rank(anchors, anchors_areas, input_image_size, eps=1e-5):
|
32 |
+
# anchors x1 y1 x2 y2
|
33 |
+
|
34 |
+
# image_size: (h, w)
|
35 |
+
# xyxy
|
36 |
+
input_image_bbox = torch.tensor([0, 0, input_image_size[1], input_image_size[0]]).unsqueeze(0)
|
37 |
+
|
38 |
+
boxes1 = anchors
|
39 |
+
boxes2 = input_image_bbox
|
40 |
+
boxes3 = anchors.clone()
|
41 |
+
# y2
|
42 |
+
boxes3[:,3] = input_image_size[0]/input_image_size[1]*anchors[:,2] # used to compute a resolution-independent IoU
|
43 |
+
|
44 |
+
area1 = anchors_areas
|
45 |
+
|
46 |
+
iou, _ = box_iou(boxes1, area1, boxes2)
|
47 |
+
iou = iou.squeeze(1)
|
48 |
+
shape_iou, _ = box_iou(boxes1, area1, boxes3)
|
49 |
+
shape_iou = shape_iou.diag()
|
50 |
+
# prefer the anchor whose aspect ratio is closest, then the one whose resolution is closest
|
51 |
+
index = torch.argmax(shape_iou*100+iou,dim=0)
|
52 |
+
return index
|
53 |
+
|
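A quick, illustrative sanity check of the anchor selection above (it assumes anchor_rank and box_iou are in scope): a 448x896 (h x w) page should prefer the one-row, two-column anchor over the square one, because shape_iou dominates the score.

import torch
from torchvision.ops.boxes import box_area

anchors = torch.tensor([[0., 0., 448., 448.],    # 1x1 grid of 448px cells
                        [0., 0., 896., 448.]])   # one row, two columns
idx = anchor_rank(anchors, box_area(anchors), input_image_size=(448, 896))
print(idx)  # tensor(1): the wide anchor wins on shape_iou*100 + iou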
54 |
+
class AnchorResize(torch.nn.Module):
|
55 |
+
|
56 |
+
def __init__(self, image_size, anchors, interpolation=InterpolationMode.BILINEAR, antialias=None):
|
57 |
+
super().__init__()
self.image_size = image_size  # kept for __repr__ below
|
58 |
+
# xyxy
|
59 |
+
self.anchors = torch.tensor(
|
60 |
+
[[0, 0, _[1]*image_size[1], _[0]*image_size[0]]
|
61 |
+
for _ in anchors], requires_grad=False
|
62 |
+
)
|
63 |
+
|
64 |
+
self.anchor_areas = box_area(self.anchors)
|
65 |
+
|
66 |
+
self.interpolation = interpolation
|
67 |
+
self.antialias = antialias
|
68 |
+
|
69 |
+
def forward(self, img, skip_resize=False):
|
70 |
+
"""
|
71 |
+
Args:
|
72 |
+
img (PIL Image or Tensor): Image to be scaled.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
PIL Image or Tensor: Rescaled image.
|
76 |
+
"""
|
77 |
+
selected_anchor = anchor_rank(self.anchors, self.anchor_areas, (img.size[1], img.size[0]))
|
78 |
+
target_size = self.anchors[selected_anchor][2:].tolist() # w,h
|
79 |
+
if skip_resize:
|
80 |
+
# for debug
|
81 |
+
return selected_anchor
|
82 |
+
return F.resize(img, [target_size[1],target_size[0]], self.interpolation, max_size=None, antialias=self.antialias), selected_anchor
|
83 |
+
|
84 |
+
def __repr__(self) -> str:
|
85 |
+
detail = f"(size={self.image_size}, anchor={self.anchors}, interpolation={self.interpolation.value}, antialias={self.antialias})"
|
86 |
+
return f"{self.__class__.__name__}{detail}"
|
87 |
+
|
88 |
+
grid_dict = {
|
89 |
+
'grid_1':[
|
90 |
+
(1,1)],
|
91 |
+
'grid_4':[
|
92 |
+
(1,1),
|
93 |
+
(1,2),(2,1),
|
94 |
+
(1,3),(3,1),
|
95 |
+
(2,2),(1,4),(4,1)],
|
96 |
+
'grid_9':[
|
97 |
+
(1,1),
|
98 |
+
(1,2),(2,1),
|
99 |
+
(1,3),(3,1),
|
100 |
+
(2,2),(1,4),(4,1),
|
101 |
+
(1,5),(5,1),
|
102 |
+
(1,6),(6,1),(2,3),(3,2),
|
103 |
+
(1,7),(7,1),
|
104 |
+
(4,2),(2,4),(1,8),(8,1),
|
105 |
+
(3,3),(1,9),(9,1)],
|
106 |
+
'grid_3x3':[
|
107 |
+
(3,3)],
|
108 |
+
'grid_20':[
|
109 |
+
(1, 1),
|
110 |
+
(1, 2), (2, 1),
|
111 |
+
(1, 3), (3, 1), (1, 4), (2, 2), (4, 1),
|
112 |
+
(1, 5), (5, 1),
|
113 |
+
(1, 6), (2, 3), (3, 2), (6, 1),
|
114 |
+
(1, 7), (7, 1),
|
115 |
+
(1, 8), (2, 4), (4, 2), (8, 1),
|
116 |
+
(1, 9), (3, 3), (9, 1),
|
117 |
+
(1, 10), (2, 5), (5, 2), (10, 1),
|
118 |
+
(1, 11), (11, 1),
|
119 |
+
(2, 6), (3, 4), (4, 3), (6, 2),
|
120 |
+
(2, 7), (7, 2),
|
121 |
+
(3, 5), (5, 3),
|
122 |
+
(2, 8), (4, 4), (8, 2),
|
123 |
+
(2, 9), (3, 6), (6, 3), (9, 2),
|
124 |
+
(2, 10), (4, 5), (5, 4), (10, 2)]
|
125 |
+
}
|
126 |
+
|
127 |
+
class DocProcessor():
|
128 |
+
def __init__(self, image_size=224, anchors='grid_9', add_global_img=True, add_textual_crop_indicator=False):
|
129 |
+
self.add_global_img = add_global_img
|
130 |
+
self.add_textual_crop_indicator = add_textual_crop_indicator
|
131 |
+
self.media_token = "<|image|>"
|
132 |
+
# h,w
|
133 |
+
if isinstance(image_size, int):
|
134 |
+
image_size = (image_size, image_size)
|
135 |
+
self.image_size = image_size
|
136 |
+
# h,w
|
137 |
+
anchors = grid_dict[anchors]
|
138 |
+
self.anchors = [tuple(_) for _ in anchors]
|
139 |
+
self.anchor_max = max([max(_) for _ in self.anchors])
|
140 |
+
# xywh -> xyxy
|
141 |
+
self.resizer = AnchorResize(image_size=image_size, anchors=anchors, interpolation=InterpolationMode.BICUBIC)
|
142 |
+
self.old_resizer = transforms.Resize(image_size,interpolation=InterpolationMode.BICUBIC)
|
143 |
+
self.image_transform = transforms.Compose([
|
144 |
+
transforms.ToTensor(),
|
145 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
146 |
+
])
|
147 |
+
|
148 |
+
def _process_image(self, images):
|
149 |
+
new_images = []
|
150 |
+
new_patch_position = []
|
151 |
+
num_image_mult = []
|
152 |
+
for image in images:
|
153 |
+
if self.add_global_img:
|
154 |
+
nocut_image = self.image_transform(self.old_resizer(image)).unsqueeze(0)
|
155 |
+
|
156 |
+
image, selected_anchor = self.resizer(image)
|
157 |
+
image_input = self.image_transform(image) # h,w,3 -> 3,h,w
|
158 |
+
# rearrange(x,'B C (n1 h) (n2 w) -> (B n1 n2) C h w', n1=self.down_sample[0], n2=self.down_sample[1])
|
159 |
+
image_input = rearrange(image_input, 'C (num_h h) (num_w w) -> (num_h num_w) C h w', h=self.image_size[0], w=self.image_size[1])
|
160 |
+
|
161 |
+
if self.add_global_img:
|
162 |
+
image_input = torch.cat([nocut_image, image_input], dim=0)
|
163 |
+
|
164 |
+
anchor = self.anchors[selected_anchor] # w,h
|
165 |
+
ic(anchor)
|
166 |
+
patch_position = torch.cat([
|
167 |
+
repeat(torch.arange(anchor[0]), 'num_h -> num_h num_w 1', num_w=anchor[1]),
|
168 |
+
repeat(torch.arange(anchor[1]), 'num_w -> num_h num_w 1', num_h=anchor[0])],dim=2)
|
169 |
+
patch_position = rearrange(patch_position, 'num_h num_w p-> (num_h num_w) p', p=2) # num_patch, (ph,pw)
|
170 |
+
|
171 |
+
if self.add_global_img:
|
172 |
+
patch_position = torch.cat([torch.ones(1,2).long()*self.anchor_max, patch_position], dim=0)
|
173 |
+
|
174 |
+
new_images.append(image_input)
|
175 |
+
new_patch_position.append(patch_position)
|
176 |
+
num_image_mult.append(patch_position.shape[0])
|
177 |
+
|
178 |
+
new_images = torch.cat(new_images,dim=0)
|
179 |
+
new_patch_position = torch.cat(new_patch_position, dim=0)
|
180 |
+
return new_images, new_patch_position, num_image_mult
|
181 |
+
|
182 |
+
def __call__(self, images=None, query=None):
|
183 |
+
assert images is not None
|
184 |
+
|
185 |
+
if not isinstance(images, list):
|
186 |
+
images = [images]
|
187 |
+
image_pils = []
|
188 |
+
for image in images:
|
189 |
+
if isinstance(image, str):
|
190 |
+
image = Image.open(image).convert('RGB')
|
191 |
+
else:
|
192 |
+
image = image.convert('RGB')
|
193 |
+
# ic(image.size)
|
194 |
+
image_pils.append(image)
|
195 |
+
|
196 |
+
image_data, patch_position, num_image_mult = self._process_image(image_pils)
|
197 |
+
|
198 |
+
assert self.media_token in query
|
199 |
+
text_list = query.split(self.media_token)
|
200 |
+
text = text_list[0]
|
201 |
+
image_token_ptr = 0
|
202 |
+
for next_text in text_list[1:]:
|
203 |
+
if self.add_textual_crop_indicator:
|
204 |
+
# generate image placeholders with interleaved textual crop indicators
|
205 |
+
# e.g. <global_img><|image|><crop_img_row0_col0><|image|><crop_img_row0_col1><|image|>...
|
206 |
+
for patch_pos in patch_position.tolist():
|
207 |
+
# global non-crop image
|
208 |
+
if patch_pos[0] == self.anchor_max and patch_pos[1] == self.anchor_max:
|
209 |
+
text += '<global_img><|image|>'
|
210 |
+
else:
|
211 |
+
row_col = 'row'+str(patch_pos[0])+'_col'+str(patch_pos[1])
|
212 |
+
text += '<crop_img_'+row_col+'><|image|>'
|
213 |
+
else:
|
214 |
+
# generate successive image placeholders for an image, 1 crop img == 1 <|image|>
|
215 |
+
text += '<|image|>'*num_image_mult[image_token_ptr]
|
216 |
+
text += next_text
|
217 |
+
image_token_ptr += 1
|
218 |
+
|
219 |
+
return image_data, patch_position, text
|
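A minimal usage sketch for DocProcessor (illustrative: the image path, the query, and the 448px crop size are placeholders; the class itself defaults to image_size=224):

from mplug_docowl.processor import DocProcessor

processor = DocProcessor(image_size=448, anchors='grid_9',
                         add_global_img=True, add_textual_crop_indicator=True)
image_data, patch_position, text = processor(
    images='page.png',                                     # a path, a PIL.Image, or a list of either
    query='<|image|>What is the title of this document?')  # must contain the <|image|> token
# image_data: (num_crops + 1, 3, 448, 448); patch_position: (num_crops + 1, 2)
# text: the query with one <|image|> placeholder (plus a crop indicator) per crop
print(image_data.shape, patch_position.shape)
print(text)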
mplug_docowl/serve/__init__.py
ADDED
File without changes
|
mplug_docowl/serve/cli.py
ADDED
@@ -0,0 +1,120 @@
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
5 |
+
from mplug_owl2.conversation import conv_templates, SeparatorStyle
|
6 |
+
from mplug_owl2.model.builder import load_pretrained_model
|
7 |
+
from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
import requests
|
12 |
+
from PIL import Image
|
13 |
+
from io import BytesIO
|
14 |
+
from transformers import TextStreamer
|
15 |
+
|
16 |
+
|
17 |
+
def disable_torch_init():
|
18 |
+
"""
|
19 |
+
Disable the redundant torch default initialization to accelerate model creation.
|
20 |
+
"""
|
21 |
+
import torch
|
22 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
23 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
24 |
+
|
25 |
+
|
26 |
+
def load_image(image_file):
|
27 |
+
if image_file.startswith('http://') or image_file.startswith('https://'):
|
28 |
+
response = requests.get(image_file)
|
29 |
+
image = Image.open(BytesIO(response.content)).convert('RGB')
|
30 |
+
else:
|
31 |
+
image = Image.open(image_file).convert('RGB')
|
32 |
+
return image
|
33 |
+
|
34 |
+
|
35 |
+
def main(args):
|
36 |
+
# Model
|
37 |
+
disable_torch_init()
|
38 |
+
|
39 |
+
model_name = get_model_name_from_path(args.model_path)
|
40 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
|
41 |
+
|
42 |
+
conv_mode = "mplug_owl2"
|
43 |
+
|
44 |
+
if args.conv_mode is not None and conv_mode != args.conv_mode:
|
45 |
+
print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
|
46 |
+
else:
|
47 |
+
args.conv_mode = conv_mode
|
48 |
+
|
49 |
+
conv = conv_templates[args.conv_mode].copy()
|
50 |
+
roles = conv.roles
|
51 |
+
|
52 |
+
image = load_image(args.image_file)
|
53 |
+
# Similar operation in model_worker.py
|
54 |
+
image_tensor = process_images([image], image_processor, args)
|
55 |
+
if type(image_tensor) is list:
|
56 |
+
image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
|
57 |
+
else:
|
58 |
+
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
|
59 |
+
|
60 |
+
while True:
|
61 |
+
try:
|
62 |
+
inp = input(f"{roles[0]}: ")
|
63 |
+
except EOFError:
|
64 |
+
inp = ""
|
65 |
+
if not inp:
|
66 |
+
print("exit...")
|
67 |
+
break
|
68 |
+
|
69 |
+
print(f"{roles[1]}: ", end="")
|
70 |
+
|
71 |
+
if image is not None:
|
72 |
+
# first message
|
73 |
+
inp = DEFAULT_IMAGE_TOKEN + inp
|
74 |
+
conv.append_message(conv.roles[0], inp)
|
75 |
+
image = None
|
76 |
+
else:
|
77 |
+
# later messages
|
78 |
+
conv.append_message(conv.roles[0], inp)
|
79 |
+
conv.append_message(conv.roles[1], None)
|
80 |
+
prompt = conv.get_prompt()
|
81 |
+
|
82 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
83 |
+
stop_str = conv.sep if conv.sep_style not in [SeparatorStyle.TWO, SeparatorStyle.TWO_NO_SYS] else conv.sep2
|
84 |
+
keywords = [stop_str]
|
85 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
86 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
87 |
+
|
88 |
+
with torch.inference_mode():
|
89 |
+
output_ids = model.generate(
|
90 |
+
input_ids,
|
91 |
+
images=image_tensor,
|
92 |
+
do_sample=True,
|
93 |
+
temperature=args.temperature,
|
94 |
+
max_new_tokens=args.max_new_tokens,
|
95 |
+
streamer=streamer,
|
96 |
+
use_cache=True,
|
97 |
+
stopping_criteria=[stopping_criteria])
|
98 |
+
|
99 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
100 |
+
conv.messages[-1][-1] = outputs
|
101 |
+
|
102 |
+
if args.debug:
|
103 |
+
print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
|
104 |
+
|
105 |
+
|
106 |
+
if __name__ == "__main__":
|
107 |
+
parser = argparse.ArgumentParser()
|
108 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
109 |
+
parser.add_argument("--model-base", type=str, default=None)
|
110 |
+
parser.add_argument("--image-file", type=str, required=True)
|
111 |
+
parser.add_argument("--device", type=str, default="cuda")
|
112 |
+
parser.add_argument("--conv-mode", type=str, default=None)
|
113 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
114 |
+
parser.add_argument("--max-new-tokens", type=int, default=512)
|
115 |
+
parser.add_argument("--load-8bit", action="store_true")
|
116 |
+
parser.add_argument("--load-4bit", action="store_true")
|
117 |
+
parser.add_argument("--debug", action="store_true")
|
118 |
+
parser.add_argument("--image-aspect-ratio", type=str, default='pad')
|
119 |
+
args = parser.parse_args()
|
120 |
+
main(args)
|
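The script above is driven by argparse; a hedged sketch of calling it programmatically follows (the checkpoint id and image path are placeholders, the Namespace fields simply mirror the flags defined above, and the script still imports from the mplug_owl2 package, which therefore must be installed):

from argparse import Namespace
from mplug_docowl.serve.cli import main

# Placeholder checkpoint and image; field names mirror the argparse flags above.
args = Namespace(model_path='MAGAer13/mplug-owl2-llama2-7b', model_base=None,
                 image_file='page.png', device='cuda', conv_mode=None,
                 temperature=0.2, max_new_tokens=512,
                 load_8bit=False, load_4bit=False, debug=False,
                 image_aspect_ratio='pad')
main(args)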
mplug_docowl/serve/controller.py
ADDED
@@ -0,0 +1,298 @@
1 |
+
"""
|
2 |
+
A controller manages distributed workers.
|
3 |
+
It sends worker addresses to clients.
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
import asyncio
|
7 |
+
import dataclasses
|
8 |
+
from enum import Enum, auto
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import time
|
12 |
+
from typing import List, Union
|
13 |
+
import threading
|
14 |
+
|
15 |
+
from fastapi import FastAPI, Request
|
16 |
+
from fastapi.responses import StreamingResponse
|
17 |
+
import numpy as np
|
18 |
+
import requests
|
19 |
+
import uvicorn
|
20 |
+
|
21 |
+
from mplug_owl2.constants import CONTROLLER_HEART_BEAT_EXPIRATION
|
22 |
+
from mplug_owl2.utils import build_logger, server_error_msg
|
23 |
+
|
24 |
+
|
25 |
+
logger = build_logger("controller", "controller.log")
|
26 |
+
|
27 |
+
|
28 |
+
class DispatchMethod(Enum):
|
29 |
+
LOTTERY = auto()
|
30 |
+
SHORTEST_QUEUE = auto()
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def from_str(cls, name):
|
34 |
+
if name == "lottery":
|
35 |
+
return cls.LOTTERY
|
36 |
+
elif name == "shortest_queue":
|
37 |
+
return cls.SHORTEST_QUEUE
|
38 |
+
else:
|
39 |
+
raise ValueError(f"Invalid dispatch method")
|
40 |
+
|
41 |
+
|
42 |
+
@dataclasses.dataclass
|
43 |
+
class WorkerInfo:
|
44 |
+
model_names: List[str]
|
45 |
+
speed: int
|
46 |
+
queue_length: int
|
47 |
+
check_heart_beat: bool
|
48 |
+
last_heart_beat: str
|
49 |
+
|
50 |
+
|
51 |
+
def heart_beat_controller(controller):
|
52 |
+
while True:
|
53 |
+
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
|
54 |
+
controller.remove_stable_workers_by_expiration()
|
55 |
+
|
56 |
+
|
57 |
+
class Controller:
|
58 |
+
def __init__(self, dispatch_method: str):
|
59 |
+
# Dict[str -> WorkerInfo]
|
60 |
+
self.worker_info = {}
|
61 |
+
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
|
62 |
+
|
63 |
+
self.heart_beat_thread = threading.Thread(
|
64 |
+
target=heart_beat_controller, args=(self,))
|
65 |
+
self.heart_beat_thread.start()
|
66 |
+
|
67 |
+
logger.info("Init controller")
|
68 |
+
|
69 |
+
def register_worker(self, worker_name: str, check_heart_beat: bool,
|
70 |
+
worker_status: dict):
|
71 |
+
if worker_name not in self.worker_info:
|
72 |
+
logger.info(f"Register a new worker: {worker_name}")
|
73 |
+
else:
|
74 |
+
logger.info(f"Register an existing worker: {worker_name}")
|
75 |
+
|
76 |
+
if not worker_status:
|
77 |
+
worker_status = self.get_worker_status(worker_name)
|
78 |
+
if not worker_status:
|
79 |
+
return False
|
80 |
+
|
81 |
+
self.worker_info[worker_name] = WorkerInfo(
|
82 |
+
worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
|
83 |
+
check_heart_beat, time.time())
|
84 |
+
|
85 |
+
logger.info(f"Register done: {worker_name}, {worker_status}")
|
86 |
+
return True
|
87 |
+
|
88 |
+
def get_worker_status(self, worker_name: str):
|
89 |
+
try:
|
90 |
+
r = requests.post(worker_name + "/worker_get_status", timeout=5)
|
91 |
+
except requests.exceptions.RequestException as e:
|
92 |
+
logger.error(f"Get status fails: {worker_name}, {e}")
|
93 |
+
return None
|
94 |
+
|
95 |
+
if r.status_code != 200:
|
96 |
+
logger.error(f"Get status fails: {worker_name}, {r}")
|
97 |
+
return None
|
98 |
+
|
99 |
+
return r.json()
|
100 |
+
|
101 |
+
def remove_worker(self, worker_name: str):
|
102 |
+
del self.worker_info[worker_name]
|
103 |
+
|
104 |
+
def refresh_all_workers(self):
|
105 |
+
old_info = dict(self.worker_info)
|
106 |
+
self.worker_info = {}
|
107 |
+
|
108 |
+
for w_name, w_info in old_info.items():
|
109 |
+
if not self.register_worker(w_name, w_info.check_heart_beat, None):
|
110 |
+
logger.info(f"Remove stale worker: {w_name}")
|
111 |
+
|
112 |
+
def list_models(self):
|
113 |
+
model_names = set()
|
114 |
+
|
115 |
+
for w_name, w_info in self.worker_info.items():
|
116 |
+
model_names.update(w_info.model_names)
|
117 |
+
|
118 |
+
return list(model_names)
|
119 |
+
|
120 |
+
def get_worker_address(self, model_name: str):
|
121 |
+
if self.dispatch_method == DispatchMethod.LOTTERY:
|
122 |
+
worker_names = []
|
123 |
+
worker_speeds = []
|
124 |
+
for w_name, w_info in self.worker_info.items():
|
125 |
+
if model_name in w_info.model_names:
|
126 |
+
worker_names.append(w_name)
|
127 |
+
worker_speeds.append(w_info.speed)
|
128 |
+
worker_speeds = np.array(worker_speeds, dtype=np.float32)
|
129 |
+
norm = np.sum(worker_speeds)
|
130 |
+
if norm < 1e-4:
|
131 |
+
return ""
|
132 |
+
worker_speeds = worker_speeds / norm
|
133 |
+
if True: # Directly return address
|
134 |
+
pt = np.random.choice(np.arange(len(worker_names)),
|
135 |
+
p=worker_speeds)
|
136 |
+
worker_name = worker_names[pt]
|
137 |
+
return worker_name
|
138 |
+
|
139 |
+
# Check status before returning
|
140 |
+
while True:
|
141 |
+
pt = np.random.choice(np.arange(len(worker_names)),
|
142 |
+
p=worker_speeds)
|
143 |
+
worker_name = worker_names[pt]
|
144 |
+
|
145 |
+
if self.get_worker_status(worker_name):
|
146 |
+
break
|
147 |
+
else:
|
148 |
+
self.remove_worker(worker_name)
|
149 |
+
worker_speeds[pt] = 0
|
150 |
+
norm = np.sum(worker_speeds)
|
151 |
+
if norm < 1e-4:
|
152 |
+
return ""
|
153 |
+
worker_speeds = worker_speeds / norm
|
154 |
+
continue
|
155 |
+
return worker_name
|
156 |
+
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
|
157 |
+
worker_names = []
|
158 |
+
worker_qlen = []
|
159 |
+
for w_name, w_info in self.worker_info.items():
|
160 |
+
if model_name in w_info.model_names:
|
161 |
+
worker_names.append(w_name)
|
162 |
+
worker_qlen.append(w_info.queue_length / w_info.speed)
|
163 |
+
if len(worker_names) == 0:
|
164 |
+
return ""
|
165 |
+
min_index = np.argmin(worker_qlen)
|
166 |
+
w_name = worker_names[min_index]
|
167 |
+
self.worker_info[w_name].queue_length += 1
|
168 |
+
logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
|
169 |
+
return w_name
|
170 |
+
else:
|
171 |
+
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
|
172 |
+
|
173 |
+
def receive_heart_beat(self, worker_name: str, queue_length: int):
|
174 |
+
if worker_name not in self.worker_info:
|
175 |
+
logger.info(f"Receive unknown heart beat. {worker_name}")
|
176 |
+
return False
|
177 |
+
|
178 |
+
self.worker_info[worker_name].queue_length = queue_length
|
179 |
+
self.worker_info[worker_name].last_heart_beat = time.time()
|
180 |
+
logger.info(f"Receive heart beat. {worker_name}")
|
181 |
+
return True
|
182 |
+
|
183 |
+
def remove_stable_workers_by_expiration(self):
|
184 |
+
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
|
185 |
+
to_delete = []
|
186 |
+
for worker_name, w_info in self.worker_info.items():
|
187 |
+
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
|
188 |
+
to_delete.append(worker_name)
|
189 |
+
|
190 |
+
for worker_name in to_delete:
|
191 |
+
self.remove_worker(worker_name)
|
192 |
+
|
193 |
+
def worker_api_generate_stream(self, params):
|
194 |
+
worker_addr = self.get_worker_address(params["model"])
|
195 |
+
if not worker_addr:
|
196 |
+
logger.info(f"no worker: {params['model']}")
|
197 |
+
ret = {
|
198 |
+
"text": server_error_msg,
|
199 |
+
"error_code": 2,
|
200 |
+
}
|
201 |
+
yield json.dumps(ret).encode() + b"\0"
|
202 |
+
|
203 |
+
try:
|
204 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
205 |
+
json=params, stream=True, timeout=5)
|
206 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
207 |
+
if chunk:
|
208 |
+
yield chunk + b"\0"
|
209 |
+
except requests.exceptions.RequestException as e:
|
210 |
+
logger.info(f"worker timeout: {worker_addr}")
|
211 |
+
ret = {
|
212 |
+
"text": server_error_msg,
|
213 |
+
"error_code": 3,
|
214 |
+
}
|
215 |
+
yield json.dumps(ret).encode() + b"\0"
|
216 |
+
|
217 |
+
|
218 |
+
# Let the controller act as a worker to achieve hierarchical
|
219 |
+
# management. This can be used to connect isolated sub networks.
|
220 |
+
def worker_api_get_status(self):
|
221 |
+
model_names = set()
|
222 |
+
speed = 0
|
223 |
+
queue_length = 0
|
224 |
+
|
225 |
+
for w_name in self.worker_info:
|
226 |
+
worker_status = self.get_worker_status(w_name)
|
227 |
+
if worker_status is not None:
|
228 |
+
model_names.update(worker_status["model_names"])
|
229 |
+
speed += worker_status["speed"]
|
230 |
+
queue_length += worker_status["queue_length"]
|
231 |
+
|
232 |
+
return {
|
233 |
+
"model_names": list(model_names),
|
234 |
+
"speed": speed,
|
235 |
+
"queue_length": queue_length,
|
236 |
+
}
|
237 |
+
|
238 |
+
|
239 |
+
app = FastAPI()
|
240 |
+
|
241 |
+
|
242 |
+
@app.post("/register_worker")
|
243 |
+
async def register_worker(request: Request):
|
244 |
+
data = await request.json()
|
245 |
+
controller.register_worker(
|
246 |
+
data["worker_name"], data["check_heart_beat"],
|
247 |
+
data.get("worker_status", None))
|
248 |
+
|
249 |
+
|
250 |
+
@app.post("/refresh_all_workers")
|
251 |
+
async def refresh_all_workers():
|
252 |
+
models = controller.refresh_all_workers()
|
253 |
+
|
254 |
+
|
255 |
+
@app.post("/list_models")
|
256 |
+
async def list_models():
|
257 |
+
models = controller.list_models()
|
258 |
+
return {"models": models}
|
259 |
+
|
260 |
+
|
261 |
+
@app.post("/get_worker_address")
|
262 |
+
async def get_worker_address(request: Request):
|
263 |
+
data = await request.json()
|
264 |
+
addr = controller.get_worker_address(data["model"])
|
265 |
+
return {"address": addr}
|
266 |
+
|
267 |
+
|
268 |
+
@app.post("/receive_heart_beat")
|
269 |
+
async def receive_heart_beat(request: Request):
|
270 |
+
data = await request.json()
|
271 |
+
exist = controller.receive_heart_beat(
|
272 |
+
data["worker_name"], data["queue_length"])
|
273 |
+
return {"exist": exist}
|
274 |
+
|
275 |
+
|
276 |
+
@app.post("/worker_generate_stream")
|
277 |
+
async def worker_api_generate_stream(request: Request):
|
278 |
+
params = await request.json()
|
279 |
+
generator = controller.worker_api_generate_stream(params)
|
280 |
+
return StreamingResponse(generator)
|
281 |
+
|
282 |
+
|
283 |
+
@app.post("/worker_get_status")
|
284 |
+
async def worker_api_get_status(request: Request):
|
285 |
+
return controller.worker_api_get_status()
|
286 |
+
|
287 |
+
|
288 |
+
if __name__ == "__main__":
|
289 |
+
parser = argparse.ArgumentParser()
|
290 |
+
parser.add_argument("--host", type=str, default="localhost")
|
291 |
+
parser.add_argument("--port", type=int, default=21001)
|
292 |
+
parser.add_argument("--dispatch-method", type=str, choices=[
|
293 |
+
"lottery", "shortest_queue"], default="shortest_queue")
|
294 |
+
args = parser.parse_args()
|
295 |
+
logger.info(f"args: {args}")
|
296 |
+
|
297 |
+
controller = Controller(args.dispatch_method)
|
298 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
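Once the controller is running (default localhost:21001), it can be queried over the HTTP routes defined above; a hedged client-side sketch (which model names appear depends on the workers that have registered):

import requests

controller = 'http://localhost:21001'
requests.post(controller + '/refresh_all_workers')          # drop stale workers, re-register the rest
models = requests.post(controller + '/list_models').json()['models']
print(models)
if models:
    addr = requests.post(controller + '/get_worker_address',
                         json={'model': models[0]}).json()['address']
    print(addr or 'no live worker for this model')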
mplug_docowl/serve/examples/Rebecca_(1939_poster)_Small.jpeg
ADDED
mplug_docowl/serve/examples/extreme_ironing.jpg
ADDED
mplug_docowl/serve/gradio_web_server.py
ADDED
@@ -0,0 +1,460 @@
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import requests
|
9 |
+
|
10 |
+
from mplug_owl2.conversation import (default_conversation, conv_templates,
|
11 |
+
SeparatorStyle)
|
12 |
+
from mplug_owl2.constants import LOGDIR
|
13 |
+
from mplug_owl2.utils import (build_logger, server_error_msg,
|
14 |
+
violates_moderation, moderation_msg)
|
15 |
+
import hashlib
|
16 |
+
|
17 |
+
|
18 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
19 |
+
|
20 |
+
headers = {"User-Agent": "mPLUG-Owl2 Client"}
|
21 |
+
|
22 |
+
no_change_btn = gr.Button.update()
|
23 |
+
enable_btn = gr.Button.update(interactive=True)
|
24 |
+
disable_btn = gr.Button.update(interactive=False)
|
25 |
+
|
26 |
+
priority = {
|
27 |
+
"vicuna-13b": "aaaaaaa",
|
28 |
+
"koala-13b": "aaaaaab",
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
def get_conv_log_filename():
|
33 |
+
t = datetime.datetime.now()
|
34 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
35 |
+
return name
|
36 |
+
|
37 |
+
|
38 |
+
def get_model_list():
|
39 |
+
ret = requests.post(args.controller_url + "/refresh_all_workers")
|
40 |
+
assert ret.status_code == 200
|
41 |
+
ret = requests.post(args.controller_url + "/list_models")
|
42 |
+
models = ret.json()["models"]
|
43 |
+
models.sort(key=lambda x: priority.get(x, x))
|
44 |
+
logger.info(f"Models: {models}")
|
45 |
+
return models
|
46 |
+
|
47 |
+
|
48 |
+
get_window_url_params = """
|
49 |
+
function() {
|
50 |
+
const params = new URLSearchParams(window.location.search);
|
51 |
+
url_params = Object.fromEntries(params);
|
52 |
+
console.log(url_params);
|
53 |
+
return url_params;
|
54 |
+
}
|
55 |
+
"""
|
56 |
+
|
57 |
+
|
58 |
+
def load_demo(url_params, request: gr.Request):
|
59 |
+
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
60 |
+
|
61 |
+
dropdown_update = gr.Dropdown.update(visible=True)
|
62 |
+
if "model" in url_params:
|
63 |
+
model = url_params["model"]
|
64 |
+
if model in models:
|
65 |
+
dropdown_update = gr.Dropdown.update(
|
66 |
+
value=model, visible=True)
|
67 |
+
|
68 |
+
state = default_conversation.copy()
|
69 |
+
return state, dropdown_update
|
70 |
+
|
71 |
+
|
72 |
+
def load_demo_refresh_model_list(request: gr.Request):
|
73 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
74 |
+
models = get_model_list()
|
75 |
+
state = default_conversation.copy()
|
76 |
+
dropdown_update = gr.Dropdown.update(
|
77 |
+
choices=models,
|
78 |
+
value=models[0] if len(models) > 0 else ""
|
79 |
+
)
|
80 |
+
return state, dropdown_update
|
81 |
+
|
82 |
+
|
83 |
+
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
84 |
+
with open(get_conv_log_filename(), "a") as fout:
|
85 |
+
data = {
|
86 |
+
"tstamp": round(time.time(), 4),
|
87 |
+
"type": vote_type,
|
88 |
+
"model": model_selector,
|
89 |
+
"state": state.dict(),
|
90 |
+
"ip": request.client.host,
|
91 |
+
}
|
92 |
+
fout.write(json.dumps(data) + "\n")
|
93 |
+
|
94 |
+
|
95 |
+
def upvote_last_response(state, model_selector, request: gr.Request):
|
96 |
+
logger.info(f"upvote. ip: {request.client.host}")
|
97 |
+
vote_last_response(state, "upvote", model_selector, request)
|
98 |
+
return ("",) + (disable_btn,) * 3
|
99 |
+
|
100 |
+
|
101 |
+
def downvote_last_response(state, model_selector, request: gr.Request):
|
102 |
+
logger.info(f"downvote. ip: {request.client.host}")
|
103 |
+
vote_last_response(state, "downvote", model_selector, request)
|
104 |
+
return ("",) + (disable_btn,) * 3
|
105 |
+
|
106 |
+
|
107 |
+
def flag_last_response(state, model_selector, request: gr.Request):
|
108 |
+
logger.info(f"flag. ip: {request.client.host}")
|
109 |
+
vote_last_response(state, "flag", model_selector, request)
|
110 |
+
return ("",) + (disable_btn,) * 3
|
111 |
+
|
112 |
+
|
113 |
+
def regenerate(state, image_process_mode, request: gr.Request):
|
114 |
+
logger.info(f"regenerate. ip: {request.client.host}")
|
115 |
+
state.messages[-1][-1] = None
|
116 |
+
prev_human_msg = state.messages[-2]
|
117 |
+
if type(prev_human_msg[1]) in (tuple, list):
|
118 |
+
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
|
119 |
+
state.skip_next = False
|
120 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
121 |
+
|
122 |
+
|
123 |
+
def clear_history(request: gr.Request):
|
124 |
+
logger.info(f"clear_history. ip: {request.client.host}")
|
125 |
+
state = default_conversation.copy()
|
126 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
127 |
+
|
128 |
+
|
129 |
+
def add_text(state, text, image, image_process_mode, request: gr.Request):
|
130 |
+
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
|
131 |
+
if len(text) <= 0 and image is None:
|
132 |
+
state.skip_next = True
|
133 |
+
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
|
134 |
+
if args.moderate:
|
135 |
+
flagged = violates_moderation(text)
|
136 |
+
if flagged:
|
137 |
+
state.skip_next = True
|
138 |
+
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
|
139 |
+
no_change_btn,) * 5
|
140 |
+
|
141 |
+
text = text[:1536] # Hard cut-off
|
142 |
+
if image is not None:
|
143 |
+
text = text[:1200] # Hard cut-off for images
|
144 |
+
if '<|image|>' not in text:
|
145 |
+
# text = text + '<|image|>'
|
146 |
+
text = '<|image|>' + text
|
147 |
+
text = (text, image, image_process_mode)
|
148 |
+
if len(state.get_images(return_pil=True)) > 0:
|
149 |
+
state = default_conversation.copy()
|
150 |
+
state.append_message(state.roles[0], text)
|
151 |
+
state.append_message(state.roles[1], None)
|
152 |
+
state.skip_next = False
|
153 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
154 |
+
|
155 |
+
|
156 |
+
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
|
157 |
+
logger.info(f"http_bot. ip: {request.client.host}")
|
158 |
+
start_tstamp = time.time()
|
159 |
+
model_name = model_selector
|
160 |
+
|
161 |
+
if state.skip_next:
|
162 |
+
# This generate call is skipped due to invalid inputs
|
163 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
164 |
+
return
|
165 |
+
|
166 |
+
if len(state.messages) == state.offset + 2:
|
167 |
+
# First round of conversation
|
168 |
+
template_name = "mplug_owl2"
|
169 |
+
new_state = conv_templates[template_name].copy()
|
170 |
+
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
171 |
+
new_state.append_message(new_state.roles[1], None)
|
172 |
+
state = new_state
|
173 |
+
|
174 |
+
# Query worker address
|
175 |
+
controller_url = args.controller_url
|
176 |
+
ret = requests.post(controller_url + "/get_worker_address",
|
177 |
+
json={"model": model_name})
|
178 |
+
worker_addr = ret.json()["address"]
|
179 |
+
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
180 |
+
|
181 |
+
# No available worker
|
182 |
+
if worker_addr == "":
|
183 |
+
state.messages[-1][-1] = server_error_msg
|
184 |
+
yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
185 |
+
return
|
186 |
+
|
187 |
+
# Construct prompt
|
188 |
+
prompt = state.get_prompt()
|
189 |
+
|
190 |
+
all_images = state.get_images(return_pil=True)
|
191 |
+
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
|
192 |
+
for image, hash in zip(all_images, all_image_hash):
|
193 |
+
t = datetime.datetime.now()
|
194 |
+
filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
|
195 |
+
if not os.path.isfile(filename):
|
196 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
197 |
+
image.save(filename)
|
198 |
+
|
199 |
+
# Make requests
|
200 |
+
pload = {
|
201 |
+
"model": model_name,
|
202 |
+
"prompt": prompt,
|
203 |
+
"temperature": float(temperature),
|
204 |
+
"top_p": float(top_p),
|
205 |
+
"max_new_tokens": min(int(max_new_tokens), 1536),
|
206 |
+
"stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
|
207 |
+
"images": f'List of {len(state.get_images())} images: {all_image_hash}',
|
208 |
+
}
|
209 |
+
logger.info(f"==== request ====\n{pload}")
|
210 |
+
|
211 |
+
pload['images'] = state.get_images()
|
212 |
+
|
213 |
+
state.messages[-1][-1] = "▌"
|
214 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
215 |
+
|
216 |
+
try:
|
217 |
+
# Stream output
|
218 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
219 |
+
headers=headers, json=pload, stream=True, timeout=10)
|
220 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
221 |
+
if chunk:
|
222 |
+
data = json.loads(chunk.decode())
|
223 |
+
if data["error_code"] == 0:
|
224 |
+
output = data["text"][len(prompt):].strip()
|
225 |
+
state.messages[-1][-1] = output + "▌"
|
226 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
227 |
+
else:
|
228 |
+
output = data["text"] + f" (error_code: {data['error_code']})"
|
229 |
+
state.messages[-1][-1] = output
|
230 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
231 |
+
return
|
232 |
+
time.sleep(0.03)
|
233 |
+
except requests.exceptions.RequestException as e:
|
234 |
+
state.messages[-1][-1] = server_error_msg
|
235 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
236 |
+
return
|
237 |
+
|
238 |
+
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
239 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
240 |
+
|
241 |
+
finish_tstamp = time.time()
|
242 |
+
logger.info(f"{output}")
|
243 |
+
|
244 |
+
with open(get_conv_log_filename(), "a") as fout:
|
245 |
+
data = {
|
246 |
+
"tstamp": round(finish_tstamp, 4),
|
247 |
+
"type": "chat",
|
248 |
+
"model": model_name,
|
249 |
+
"start": round(start_tstamp, 4),
|
250 |
+
"finish": round(start_tstamp, 4),
|
251 |
+
"state": state.dict(),
|
252 |
+
"images": all_image_hash,
|
253 |
+
"ip": request.client.host,
|
254 |
+
}
|
255 |
+
fout.write(json.dumps(data) + "\n")
|
256 |
+
|
257 |
+
|
258 |
+
title_markdown = ("""
|
259 |
+
<h1 align="center"><a href="https://github.com/X-PLUG/mPLUG-Owl"><img src="https://z1.ax1x.com/2023/11/03/piM1rGQ.md.png", alt="mPLUG-Owl" border="0" style="margin: 0 auto; height: 200px;" /></a> </h1>
|
260 |
+
|
261 |
+
<h2 align="center"> mPLUG-Owl2: Revolutionizing Multi-modal Large Language Model with Modality Collaboration</h2>
|
262 |
+
|
263 |
+
<h5 align="center"> If you like our project, please give us a star ✨ on Github for latest update. </h2>
|
264 |
+
|
265 |
+
<div align="center">
|
266 |
+
<div style="display:flex; gap: 0.25rem;" align="center">
|
267 |
+
<a href='https://github.com/X-PLUG/mPLUG-Owl'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
<a href="https://arxiv.org/abs/2304.14178"><img src="https://img.shields.io/badge/Arxiv-2304.14178-red"></a>
<a href='https://github.com/X-PLUG/mPLUG-Owl/stargazers'><img src='https://img.shields.io/github/stars/X-PLUG/mPLUG-Owl.svg?style=social'></a>
</div>
</div>

""")


tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")


learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")

block_css = """

#buttons button {
    min-width: min(120px,100%);
}

"""

def build_demo(embed_mode):
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="mPLUG-Owl2", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
                    [f"{cur_dir}/examples/Rebecca_(1939_poster)_Small.jpeg", "What is the name of the movie in the poster?"],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=True) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(elem_id="Chatbot", label="mPLUG-Owl2 Chatbot", height=600)
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )

        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                _js=get_window_url_params,
                queue=False
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--model-list-mode", type=str, default="once",
                        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    models = get_model_list()

    logger.info(args)
    demo = build_demo(args.embed)
    demo.queue(
        concurrency_count=args.concurrency_count,
        api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=False
    )
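Note on the event wiring above: each submit/click first runs the synchronous handler (add_text or regenerate) to update the conversation state and clear the inputs, and only the chained .then(http_bot, ...) step streams the model reply into the chatbot. The following is a minimal, self-contained sketch of the same chaining pattern with hypothetical echo handlers, not the repo's add_text/http_bot:

import gradio as gr

def add_stub(history, text):
    # mirrors add_text: append the user turn, clear the textbox
    history = (history or []) + [[text, None]]
    return history, ""

def bot_stub(history):
    # mirrors http_bot: fill in the assistant turn (a trivial echo here)
    history[-1][1] = f"echo: {history[-1][0]}"
    return history

with gr.Blocks() as sketch_demo:
    chat = gr.Chatbot()
    box = gr.Textbox()
    box.submit(add_stub, [chat, box], [chat, box], queue=False).then(bot_stub, chat, chat)

if __name__ == "__main__":
    sketch_demo.queue().launch()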
mplug_docowl/serve/model_worker.py
ADDED
@@ -0,0 +1,342 @@
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
from functools import partial

from mplug_docowl.utils import (build_logger, server_error_msg,
                                pretty_print_semaphore)

from mplug_docowl.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, WORKER_HEART_BEAT_INTERVAL
from mplug_docowl.conversation import conv_templates, SeparatorStyle
from mplug_docowl.model.builder import load_pretrained_model
from mplug_docowl.mm_utils import load_image_from_base64, process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from mplug_docowl.processor import DocProcessor
from icecream import ic

from transformers import TextStreamer, TextIteratorStreamer
from threading import Thread


GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0

model_semaphore = None


def heart_beat_worker(controller):

    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()


class DocOwlInfer():
    def __init__(self, ckpt_path, anchors='grid_9', add_global_img=True, load_8bit=False, load_4bit=False):
        model_name = get_model_name_from_path(ckpt_path)
        ic(model_name)
        self.tokenizer, self.model, _, _ = load_pretrained_model(ckpt_path, None, model_name, load_8bit=load_8bit, load_4bit=load_4bit, device="cuda")
        self.doc_image_processor = DocProcessor(image_size=448, anchors=anchors, add_global_img=add_global_img, add_textual_crop_indicator=True)
        self.streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

    def inference(self, image, query):
        image_tensor, patch_positions, text = self.doc_image_processor(images=image, query='<|image|>' + query)
        image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
        patch_positions = patch_positions.to(self.model.device)

        # ic(image_tensor.shape, patch_positions.shape, text)

        conv = conv_templates["mplug_owl2"].copy()
        roles = conv.roles  # ("USER", "ASSISTANT")

        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # ic(prompt)

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.model.device)

        # ic(input_ids)

        stop_str = conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                patch_positions=patch_positions,
                do_sample=False,
                temperature=1.0,
                max_new_tokens=512,
                streamer=self.streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

        return outputs.replace('</s>', '')


# TODO: adapt for docowl infer
class ModelWorker:
    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 resolution, anchors, add_global_img,
                 load_8bit, load_4bit, device):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]

        self.model_name = get_model_name_from_path(model_path)

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")

        self.tokenizer, self.model, _, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)

        self.resolution = resolution
        self.token_num_each_img = (self.resolution / 14) * (self.resolution / 14) / self.model.get_model().vison2text.conv_patch
        self.doc_image_processor = DocProcessor(image_size=resolution, anchors=anchors, add_global_img=add_global_img, add_textual_crop_indicator=True)

        self.is_multimodal = True

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        if model_semaphore is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @torch.inference_mode()
    def generate_stream(self, params):
        # image preprocessing is handled by self.doc_image_processor; the raw CLIP
        # image processor returned by load_pretrained_model is not used here
        tokenizer, model = self.tokenizer, self.model

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        num_image_tokens = 0
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) > 0:

                """if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                    raise ValueError("Number of images does not match number of <|image|> tokens in prompt")

                images = [load_image_from_base64(image) for image in images]
                images = process_images(images, image_processor, model.config)

                if type(images) is list:
                    images = [image.to(self.model.device, dtype=torch.float16) for image in images]
                else:
                    images = images.to(self.model.device, dtype=torch.float16)"""

                # docowl only support 1 image, so only keep the last image
                image = images[-1]
                assert prompt.count(DEFAULT_IMAGE_TOKEN) == 1

                image_tensor, patch_positions, prompt = self.doc_image_processor(images=image, query=prompt)
                image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
                patch_positions = patch_positions.to(self.model.device)
                # hand the processed crop tensor (not the raw image list) to generate()
                images = image_tensor

                replace_token = DEFAULT_IMAGE_TOKEN
                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
                num_image_tokens = int(prompt.count(replace_token) * (self.token_num_each_img + 1))
            else:
                images = None
                patch_positions = None
            image_args = {"images": images, "patch_positions": patch_positions}
        else:
            images = None
            image_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = True if temperature > 0.001 else False

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        if max_new_tokens < 1:
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
            return

        thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            stopping_criteria=[stopping_criteria],
            use_cache=True,
            **image_args
        ))
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            if generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"


app = FastAPI()


def release_model_semaphore(fn=None):
    model_semaphore.release()
    if fn is not None:
        fn()


@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_get_status")
async def get_status(request: Request):
    return worker.get_status()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
                        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
                        default="http://localhost:21001")
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    # crop resolution and anchor grid required by ModelWorker.__init__;
    # defaults mirror the DocProcessor settings used by DocOwlInfer above
    parser.add_argument("--resolution", type=int, default=448)
    parser.add_argument("--anchors", type=str, default="grid_9")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")


    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_base,
                         args.model_name,
                         args.resolution,
                         args.anchors,
                         True,  # add_global_img, as in DocOwlInfer
                         args.load_8bit,
                         args.load_4bit,
                         args.device)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
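The worker above streams its output as null-byte-delimited JSON chunks, each carrying the full text generated so far plus an error_code (see generate_stream and generate_stream_gate). Below is a sketch of a matching client; the worker URL, prompt format, and sampling values are illustrative assumptions, not values fixed by this repo:

import json
import requests

worker_addr = "http://localhost:21002"  # default --port above
payload = {
    "prompt": "USER: <|image|>What is the title of this document? ASSISTANT:",  # assumed conversation format
    "images": ["<base64-encoded image>"],  # the worker keeps only the last image
    "temperature": 0.2,
    "top_p": 0.7,
    "max_new_tokens": 256,
    "stop": "</s>",
}
resp = requests.post(worker_addr + "/worker_generate_stream", json=payload, stream=True, timeout=60)
text = ""
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if not chunk:
        continue
    data = json.loads(chunk.decode())
    if data["error_code"] != 0:
        raise RuntimeError(data["text"])
    text = data["text"]  # cumulative text so far (prompt + partial reply)
print(text)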
mplug_docowl/serve/model_worker_bak.py
ADDED
@@ -0,0 +1,278 @@
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
from functools import partial

from mplug_owl2.constants import WORKER_HEART_BEAT_INTERVAL
from mplug_owl2.utils import (build_logger, server_error_msg,
                              pretty_print_semaphore)
from mplug_owl2.model.builder import load_pretrained_model
from mplug_owl2.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from transformers import TextIteratorStreamer
from threading import Thread


GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0

model_semaphore = None


def heart_beat_worker(controller):

    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()


class ModelWorker:
    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 load_8bit, load_4bit, device):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
        self.is_multimodal = True

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        if model_semaphore is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @torch.inference_mode()
    def generate_stream(self, params):
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        num_image_tokens = 0
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) > 0:
                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                    raise ValueError("Number of images does not match number of <|image|> tokens in prompt")

                images = [load_image_from_base64(image) for image in images]
                images = process_images(images, image_processor, model.config)

                if type(images) is list:
                    images = [image.to(self.model.device, dtype=torch.float16) for image in images]
                else:
                    images = images.to(self.model.device, dtype=torch.float16)

                replace_token = DEFAULT_IMAGE_TOKEN
                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

                num_image_tokens = prompt.count(replace_token) * (model.get_model().visual_abstractor.config.num_learnable_queries + 1)
            else:
                images = None
            image_args = {"images": images}
        else:
            images = None
            image_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = True if temperature > 0.001 else False

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        if max_new_tokens < 1:
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
            return

        thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            stopping_criteria=[stopping_criteria],
            use_cache=True,
            **image_args
        ))
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            if generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"


app = FastAPI()


def release_model_semaphore(fn=None):
    model_semaphore.release()
    if fn is not None:
        fn()


@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_get_status")
async def get_status(request: Request):
    return worker.get_status()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
                        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
                        default="http://localhost:21001")
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")


    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_base,
                         args.model_name,
                         args.load_8bit,
                         args.load_4bit,
                         args.device)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
mplug_docowl/serve/register_workers.py
ADDED
@@ -0,0 +1,26 @@
"""
Manually register workers.

Usage:
python3 -m mplug_docowl.serve.register_workers --controller-address http://localhost:21001 --worker-name http://localhost:21002
"""

import argparse

import requests

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-address", type=str)
    parser.add_argument("--worker-name", type=str)
    parser.add_argument("--check-heart-beat", action="store_true")
    args = parser.parse_args()

    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": args.worker_name,
        "check_heart_beat": args.check_heart_beat,
        "worker_status": None,
    }
    r = requests.post(url, json=data)
    assert r.status_code == 200
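A possible variant of the manual registration above: fetch a live status snapshot from the worker's /worker_get_status endpoint (defined in model_worker.py) and pass it instead of None, so the controller immediately knows the model name and queue length. This is only a sketch, assuming the default local ports used elsewhere in this repo:

import requests

controller = "http://localhost:21001"
worker = "http://localhost:21002"

# {"model_names": [...], "speed": 1, "queue_length": ...}, as returned by ModelWorker.get_status()
status = requests.post(worker + "/worker_get_status", timeout=5).json()
r = requests.post(controller + "/register_worker", json={
    "worker_name": worker,
    "check_heart_beat": True,
    "worker_status": status,
})
assert r.status_code == 200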
mplug_docowl/train/llama_flash_attn_monkey_patch.py
ADDED
@@ -0,0 +1,117 @@
from typing import Optional, Tuple
import warnings

import torch

import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


def forward(
    self,
    hidden_states: torch.Tensor,
    modality_indicators: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: bool = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states, modality_indicators)
        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states, modality_indicators)
        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        .transpose(1, 2)
    )  # shape: (b, num_heads, s, head_dim)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )

    if past_key_value is not None:
        # reuse k, v
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # Transform the data into the format required by flash attention
    qkv = torch.stack([query_states, key_states, value_states], dim=2)
    qkv = qkv.transpose(1, 3)  # shape: [b, s, 3, num_heads, head_dim]
    key_padding_mask = attention_mask

    if key_padding_mask is None:
        qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
        )
        max_s = q_len
        output = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = output.view(bsz, q_len, -1)
    else:
        qkv = qkv.reshape(bsz, q_len, -1)
        qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
        qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
        output = pad_input(output_unpad, indices, bsz, q_len)

    return self.o_proj(output), None, past_key_value


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    return attention_mask


def replace_llama_attn_with_flash_attn():
    cuda_major, cuda_minor = torch.cuda.get_device_capability()
    if cuda_major < 8:
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. "
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask
    )
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
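Typical usage of this patch (an assumption following the common FastChat/LLaVA pattern, not something shown in the files above): call replace_llama_attn_with_flash_attn() before any LlamaAttention modules are constructed, so they pick up the patched forward. Note that the patched forward takes an extra modality_indicators argument, so it is intended for the modality-aware LLaMA variant used in this codebase rather than a stock LlamaForCausalLM.

# Sketch only; the training entry point is assumed to apply the patch first.
from mplug_docowl.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

def main():
    replace_llama_attn_with_flash_attn()  # must run before the model is built
    # ... then construct the model and trainer as train.py does
    ...

if __name__ == "__main__":
    main()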
mplug_docowl/train/mplug_owl2_trainer.py
ADDED
@@ -0,0 +1,243 @@
import os
import torch
import torch.nn as nn

from torch.utils.data import Sampler

from transformers import Trainer
from transformers.trainer import (
    is_sagemaker_mp_enabled,
    get_parameter_names,
    has_length,
    ALL_LAYERNORM_LAYERS,
    ShardedDDPOption,
    logger,
)
from typing import List, Optional
from icecream import ic

def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                print(name, 'no ignore status')
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
    return to_return


def split_to_even_chunks(indices, lengths, num_chunks):
    """
    Split a list of indices into `chunks` chunks of roughly equal lengths.
    """

    if len(indices) % num_chunks != 0:
        return [indices[i::num_chunks] for i in range(num_chunks)]

    num_indices_per_chunk = len(indices) // num_chunks

    chunks = [[] for _ in range(num_chunks)]
    chunks_lengths = [0 for _ in range(num_chunks)]
    for index in indices:
        shortest_chunk = chunks_lengths.index(min(chunks_lengths))
        chunks[shortest_chunk].append(index)
        chunks_lengths[shortest_chunk] += lengths[index]
        if len(chunks[shortest_chunk]) == num_indices_per_chunk:
            chunks_lengths[shortest_chunk] = float("inf")

    return chunks


def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # all samples are in the same modality
        return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])

    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
    lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
    megabatch_size = world_size * batch_size
    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]

    last_mm = mm_megabatches[-1]
    last_lang = lang_megabatches[-1]
    additional_batch = last_mm + last_lang
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in megabatch_indices]

    if len(additional_batch) > 0:
        megabatches.append(sorted(additional_batch))

    return [i for megabatch in megabatches for i in megabatch]


def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = world_size * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
    megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]

    return [i for megabatch in megabatches for batch in megabatch for i in batch]


class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
    keeping a bit of randomness.
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        generator=None,
        group_by_modality: bool = False,
    ):
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size = batch_size
        self.world_size = world_size
        self.lengths = lengths
        self.generator = generator
        self.group_by_modality = group_by_modality

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        if self.group_by_modality:
            indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        else:
            indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        return iter(indices)


class MPLUGOwl2Trainer(Trainer):

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        if self.args.group_by_modality_length:
            lengths = self.train_dataset.modality_lengths
            return LengthGroupedSampler(
                self.args.train_batch_size,
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
                group_by_modality=True,
            )
        else:
            return super()._get_train_sampler()

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            if self.args.visual_abstractor_lr is not None:
                projector_parameters = [name for name, _ in opt_model.named_parameters() if "visual_abstractor" in name]
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                        "lr": self.args.visual_abstractor_lr,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                        "lr": self.args.visual_abstractor_lr,
                    },
                ]
            else:
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]
            ic(len(optimizer_grouped_parameters[0]['params']), len(optimizer_grouped_parameters[1]['params']))
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                            logger.info(f"skipped {module}: {skipped/2**20}M params")
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                    logger.info(f"skipped: {skipped/2**20}M params")

        return self.optimizer

    def _save_checkpoint(self, model, trial, metrics=None):
        super(MPLUGOwl2Trainer, self)._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        super(MPLUGOwl2Trainer, self)._save(output_dir, state_dict)
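A toy illustration (not repo code) of the length-grouping convention the sampler above relies on: positive lengths mark multimodal samples, negative lengths mark text-only samples, and indices are arranged so each world_size * batch_size megabatch contains a single modality, apart from the final leftover batch. The lengths below are made up:

import torch
from mplug_docowl.train.mplug_owl2_trainer import LengthGroupedSampler

# positive = multimodal sample lengths, negative = language-only sample lengths (toy values)
lengths = [120, 80, 95, 150, 60, 210, 45, 75,
           -120, -80, -95, -150, -60, -210, -45, -75]
sampler = LengthGroupedSampler(
    batch_size=2,
    world_size=2,
    lengths=lengths,
    generator=torch.Generator().manual_seed(0),
    group_by_modality=True,
)
print(list(sampler))  # 16 indices; the non-final megabatches of 4 are single-modality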
mplug_docowl/train/train.py
ADDED
@@ -0,0 +1,801 @@
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List

import torch

import transformers
from transformers.models.clip.image_processing_clip import CLIPImageProcessor

from torch.utils.data import Dataset
from mplug_owl2.train.mplug_owl2_trainer import MPLUGOwl2Trainer
from mplug_owl2.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN

from mplug_owl2 import conversation as conversation_lib
from mplug_owl2.model import *
from mplug_owl2.mm_utils import tokenizer_image_token

from PIL import Image
from icecream import ic

local_rank = None


def rank0_print(*args):
    if local_rank == 0:
        print(*args)


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    version: Optional[str] = field(default="v0")
    freeze_backbone: bool = field(default=False)


@dataclass
class DataArguments:
    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    image_aspect_ratio: str = 'square'
    image_grid_pinpoints: Optional[str] = field(default=None)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    remove_unused_columns: bool = field(default=False)

    tune_visual_abstractor: bool = field(default=True)
    freeze_vision_model: bool = field(default=True)

    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    visual_abstractor_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)

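# NOTE (added for illustration, not part of the original recipe): the three
# dataclasses above are parsed together by transformers.HfArgumentParser in
# train() below, so every field becomes a command-line flag. A minimal sketch
# of a launch command, with placeholder paths and illustrative values:
#
#   torchrun --nproc_per_node 8 train.py \
#       --model_name_or_path /path/to/base_model \
#       --version v1 \
#       --data_path /path/to/train_data.json \
#       --image_folder /path/to/images \
#       --image_aspect_ratio pad \
#       --tune_visual_abstractor True \
#       --freeze_vision_model True \
#       --bf16 True \
#       --model_max_length 2048 \
#       --output_dir ./checkpoints/sft
#
# Flags not listed fall back to the defaults declared in the dataclasses or in
# transformers.TrainingArguments.
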
def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param

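# NOTE (added for illustration): maybe_zero_3 exists because, under DeepSpeed
# ZeRO-3, parameter tensors are sharded across ranks and the locally visible
# .data can be an empty placeholder. zero.GatheredParameters([param])
# temporarily reassembles the full tensor so it can be detached and copied to
# CPU; without ZeRO (no ds_id attribute) the parameter is simply cloned to CPU.
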

# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['vision_model', 'visual_abstractor']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            lora_module_names.add(name)

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collects the state dict and dump to disk."""

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {
            key: value.cpu()
            for key, value in state_dict.items()
        }
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg

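# NOTE (added for illustration): a hypothetical call to the helper above, e.g.
# to introduce a dedicated pad token instead of reusing the unk token:
#
#   smart_tokenizer_and_embedding_resize(
#       special_tokens_dict=dict(pad_token="[PAD]"),
#       tokenizer=tokenizer,
#       model=model,
#   )
#
# Newly added rows of both embedding matrices are initialized to the mean of
# the existing embeddings, which tends to behave better than random init.
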

def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [
        tokenized.input_ids[0] for tokenized in tokenized_list
    ]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def _mask_targets(target, tokenized_lens, speakers):
    # cur_idx = 0
    cur_idx = tokenized_lens[0]
    tokenized_lens = tokenized_lens[1:]
    target[:cur_idx] = IGNORE_INDEX
    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len


def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round."""
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = conversation_lib.default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = conversation_lib.default_conversation.roles[1]
        else:
            from_str = 'unknown'
        sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
                             sentence["value"] + END_SIGNAL)
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
    return conversation


def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments
) -> Dict:
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence['value']:
                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
                sentence['value'] = sentence['value'].strip()

            replace_token = DEFAULT_IMAGE_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

    return sources

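# NOTE (added for illustration): preprocess_multimodal only normalizes where
# the image placeholder sits in the text. For a multimodal sample, a turn such
# as "What is in the picture? <IMAGE>" becomes "<IMAGE>\nWhat is in the
# picture?": the token is stripped from wherever it appears and re-attached at
# the front, followed by a newline, so every conversation references the image
# in a consistent position. (<IMAGE> stands here for the literal value of
# DEFAULT_IMAGE_TOKEN imported from mplug_owl2.constants.)
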

def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations

    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )

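# NOTE (added for illustration): the loss masking in preprocess_v1 works per
# round. Each conversation is split on conv.sep2 (the end-of-round separator),
# and within a round everything up to and including the assistant separator
# (conv.sep + conv.roles[1] + ": ") is overwritten with IGNORE_INDEX, so only
# the assistant's reply tokens contribute to the loss. If the recomputed
# length disagrees with the tokenized length, the whole sample is masked out
# and a warning is printed rather than training on misaligned labels.
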

def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    # add end signal and concatenate together
    conversations = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_IMAGE_TOKEN in source[0]['value']
        source[0]['value'] = DEFAULT_IMAGE_TOKEN
        conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
        conversations.append(conversation)
    # tokenize conversations
    input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)


def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """
    Given a list of sources, each is a conversation list. This transform:
    1. Add the signal '### ' at the beginning of each sentence, with end signal '\n';
    2. Concatenate conversations together;
    3. Tokenize the concatenated conversation;
    4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
    """
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if conversation_lib.default_conversation.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image)
    # add end signal and concatenate together
    conversations = []
    for source in sources:
        header = f"{conversation_lib.default_conversation.system}\n\n"
        conversation = _add_speaker_and_signal(header, source)
        conversations.append(conversation)
    # tokenize conversations
    def get_tokenize_len(prompts):
        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
    if has_image:
        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    else:
        conversations_tokenized = _tokenize_fn(conversations, tokenizer)
        input_ids = conversations_tokenized["input_ids"]

    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        if has_image:
            tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
        else:
            tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)

    return dict(input_ids=input_ids, labels=targets)

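# NOTE (added for illustration): preprocess is the single entry point used by
# the dataset below. It dispatches on the active conversation template:
# SeparatorStyle.PLAIN goes to preprocess_plain (caption-style pretraining
# pairs), templates whose version starts with "v1" go to preprocess_v1, and
# anything else falls through to the legacy "### speaker" format built by
# _add_speaker_and_signal and masked by _mask_targets.
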

class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = json.load(open(data_path, "r"))

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if 'image' in sample else 0
            length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
            cur_len = cur_len if 'image' in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    def next_rand(self):
        import random
        return random.randint(0, len(self) - 1)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        while True:
            sources = self.list_data_dict[i]
            if isinstance(i, int):
                sources = [sources]
            assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
            if 'image' in sources[0]:
                image_file = self.list_data_dict[i]['image']
                image_folder = self.data_args.image_folder
                processor = self.data_args.image_processor
                from pathlib import Path
                if not Path(os.path.join(image_folder, image_file)).exists():
                    i = self.next_rand()
                    continue
                image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
                if self.data_args.image_aspect_ratio == 'pad':
                    def expand2square(pil_img, background_color):
                        width, height = pil_img.size
                        if width == height:
                            return pil_img
                        elif width > height:
                            result = Image.new(pil_img.mode, (width, width), background_color)
                            result.paste(pil_img, (0, (width - height) // 2))
                            return result
                        else:
                            result = Image.new(pil_img.mode, (height, height), background_color)
                            result.paste(pil_img, ((height - width) // 2, 0))
                            return result
                    image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
                    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
                else:
                    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
                sources = preprocess_multimodal(
                    copy.deepcopy([e["conversations"] for e in sources]),
                    self.data_args)
            else:
                sources = copy.deepcopy([e["conversations"] for e in sources])
            data_dict = preprocess(
                sources,
                self.tokenizer,
                has_image=('image' in self.list_data_dict[i]))
            if isinstance(i, int):
                data_dict = dict(input_ids=data_dict["input_ids"][0],
                                 labels=data_dict["labels"][0])

            # image exists in the data
            if 'image' in self.list_data_dict[i]:
                data_dict['image'] = image
            elif self.data_args.is_multimodal:
                # image does not exist in the data, but the model is multimodal
                crop_size = self.data_args.image_processor.crop_size
                data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
            return data_dict

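# NOTE (added for illustration): two behaviors of LazySupervisedDataset worth
# calling out. First, __getitem__ loops until it finds a usable sample: if the
# referenced image file is missing on disk, it resamples a random index via
# next_rand() instead of raising. Second, with image_aspect_ratio == 'pad' the
# image is letterboxed to a square using the processor's per-channel mean
# color before CLIP preprocessing, so the aspect ratio is preserved instead of
# being squashed by a square resize.
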

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images

        return batch


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                          data_path=data_args.data_path,
                                          data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)

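# NOTE (added for illustration): the dict returned by make_supervised_data_module
# is unpacked directly into the trainer in train() below. Each collated batch
# carries input_ids, labels, and attention_mask (non-pad positions), plus an
# 'images' entry when the instances contain images: a stacked tensor if all
# image tensors share a shape, otherwise a plain list.
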

def train():
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type  # {'fp4', 'nf4'}
            )
        ))

    model = MPLUGOwl2LlamaForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        **bnb_model_from_pretrained_args
    )
    model.config.use_cache = False

    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training
        model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        rank0_print("Adding LoRA adapters...")
        model = get_peft_model(model, lora_config)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    tokenizer.pad_token = tokenizer.unk_token
    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]

    if not training_args.freeze_vision_model and training_args.bits in [4, 8]:
        model.get_model().vision_model.to(dtype=compute_dtype, device=training_args.device)
    else:
        vision_tower = model.get_model().vision_model
        vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

    if training_args.tune_visual_abstractor and training_args.bits in [4, 8]:
        model.get_model().visual_abstractor.to(dtype=compute_dtype, device=training_args.device)
    else:
        visual_abstractor = model.get_model().visual_abstractor
        visual_abstractor.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

    data_args.image_processor = CLIPImageProcessor.from_pretrained(model_args.model_name_or_path)
    data_args.is_multimodal = True

    model.config.image_aspect_ratio = data_args.image_aspect_ratio
    model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
    model.config.tune_visual_abstractor = model_args.tune_visual_abstractor = training_args.tune_visual_abstractor
    ic(training_args.tune_visual_abstractor)
    model.requires_grad_(True)
    if training_args.tune_visual_abstractor:
        # model.requires_grad_(False)
        for p in model.get_model().visual_abstractor.parameters():
            p.requires_grad = True

    model.config.freeze_vision_model = training_args.freeze_vision_model
    ic(training_args.freeze_vision_model)
    if training_args.freeze_vision_model:
        for p in model.get_model().vision_model.parameters():
            p.requires_grad = False

    model.config.visual_abstractor_lr = training_args.visual_abstractor_lr

    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)
    trainer = MPLUGOwl2Trainer(model=model,
                               tokenizer=tokenizer,
                               args=training_args,
                               **data_module)

    # if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
    #     trainer.train(resume_from_checkpoint=True)
    # else:
    #     trainer.train()

    # TODO I dont like auto resume << REMOVE IT AND UNCOMMENT THE ABOVE CODE
    trainer.train()

    trainer.save_state()

    model.config.use_cache = True

    if training_args.lora_enable:
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters()
        )
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        safe_save_model_for_hf_trainer(trainer=trainer,
                                       output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()