tinyllava committed
Commit
58845c3
1 Parent(s): c1585b5

upload python file

Files changed (2)
  1. configuration.py +116 -0
  2. generate_model.py +730 -0
configuration.py ADDED
@@ -0,0 +1,116 @@
from transformers import PretrainedConfig
from transformers import CONFIG_MAPPING
from transformers import AutoConfig

IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"


class TinyLlavaConfig(PretrainedConfig):

    model_type = "tinyllava"

    def __init__(
        self,
        llm_model_name_or_path='',
        tokenizer_name_or_path=None,
        vision_model_name_or_path='',
        vision_model_name_or_path2='',
        connector_type=None,
        text_config=None,
        hidden_size=2048,
        vocab_size=32000,
        ignore_index=-100,
        image_token_index=32000,
        pad_token=None,
        pad_token_id=None,
        tokenizer_padding_side='right',
        tokenizer_model_max_length=2048,
        vision_config=None,
        vision_hidden_size=None,
        vision_feature_layer=-2,
        vision_feature_select_strategy='patch',
        image_aspect_ratio='square',
        resampler_hidden_size=None,
        num_queries=None,
        num_resampler_layers=None,
        use_cache=False,
        cache_dir=None,
        tokenizer_use_fast=False,
        tune_type_llm='frozen',
        tune_type_connector='frozen',
        tune_type_vision_tower='frozen',
        tune_vision_tower_from_layer=-1,
        **kwargs
    ):
        self.llm_model_name_or_path = llm_model_name_or_path
        self.tokenizer_name_or_path = tokenizer_name_or_path or self.llm_model_name_or_path
        self.vision_model_name_or_path = vision_model_name_or_path
        self.vision_model_name_or_path2 = vision_model_name_or_path2
        self.connector_type = connector_type
        self.tune_type_llm = tune_type_llm
        self.tune_type_connector = tune_type_connector
        self.tune_type_vision_tower = tune_type_vision_tower
        self.tune_vision_tower_from_layer = tune_vision_tower_from_layer

        self.ignore_index = IGNORE_INDEX
        self.image_token_index = IMAGE_TOKEN_INDEX
        self.pad_token = pad_token
        self.pad_token_id = pad_token_id
        self.tokenizer_padding_side = tokenizer_padding_side
        self.tokenizer_model_max_length = tokenizer_model_max_length
        self.vision_feature_layer = vision_feature_layer
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.image_aspect_ratio = image_aspect_ratio
        self.resampler_hidden_size = resampler_hidden_size
        self.num_queries = num_queries
        self.num_resampler_layers = num_resampler_layers
        self.use_cache = use_cache
        self.cache_dir = cache_dir
        self.tokenizer_use_fast = tokenizer_use_fast
        self._load_text_config(text_config)
        self._load_vision_config(vision_config)

        super().__init__(**kwargs)

    def _load_text_config(self, text_config=None):
        if self.llm_model_name_or_path is None or self.llm_model_name_or_path == '':
            self.text_config = CONFIG_MAPPING['llama']()
        else:
            self.text_config = AutoConfig.from_pretrained(self.llm_model_name_or_path, trust_remote_code=True)
            if text_config is not None:
                self.text_config = self.text_config.from_dict(text_config)

        self.hidden_size = getattr(self.text_config, 'hidden_size', getattr(self.text_config, 'model_dim', None))
        self.vocab_size = getattr(self.text_config, 'vocab_size', None)

    def _load_vision_config(self, vision_config=None):
        if self.vision_model_name_or_path is None or self.vision_model_name_or_path == '':
            self.vision_config = CONFIG_MAPPING['clip_vision_model'](
                intermediate_size=4096,
                hidden_size=1024,
                patch_size=14,
                image_size=336,
                num_hidden_layers=24,
                num_attention_heads=16,
                vocab_size=32000,
                projection_dim=768,
            )
        else:
            self.vision_config = AutoConfig.from_pretrained(self.vision_model_name_or_path.split(':')[-1])
            self.vision_config = getattr(self.vision_config, 'vision_config', self.vision_config)
            if vision_config is not None:
                self.vision_config = self.vision_config.from_dict(vision_config)

        self.vision_config.model_name_or_path = self.vision_model_name_or_path.split(':')[-1]
        self.vision_config.model_name_or_path2 = self.vision_model_name_or_path2.split(':')[-1]
        self.vision_hidden_size = getattr(self.vision_config, 'hidden_size', None)
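
For reference, a minimal usage sketch (not part of the commit) of how TinyLlavaConfig resolves its sub-configs: with empty model paths, _load_text_config() falls back to the default llama text config and _load_vision_config() builds the hard-coded CLIP vision config shown above. The import assumes the file is used as a module named configuration.

# Illustrative sketch only, not part of the uploaded files.
from configuration import TinyLlavaConfig

cfg = TinyLlavaConfig(llm_model_name_or_path='', vision_model_name_or_path='')

print(cfg.model_type)              # "tinyllava"
print(cfg.hidden_size)             # taken from the resolved text config
print(cfg.vision_hidden_size)      # 1024, from the default CLIP vision config
print(cfg.tokenizer_name_or_path)  # falls back to llm_model_name_or_path
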
generate_model.py ADDED
@@ -0,0 +1,730 @@
import argparse
import ast
import base64
import dataclasses
import logging
import math
import os
import time
from enum import auto, Enum
from io import BytesIO
from typing import List, Tuple

import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    TINY_LLAMA = auto()
    QWEN_2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.TINY_LLAMA:
            sep = "</s>"
            wrap_sys = lambda msg: f"<|system|>\n{msg}\n"
            wrap_user = lambda msg: f"<|user|>\n{msg}\n"
            wrap_assistant = lambda msg: f"<|assistant|>\n{msg}"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        message = wrap_user(message)
                        if i == 0:
                            message = wrap_sys(self.system) + message
                        ret += self.sep + message
                    else:
                        message = wrap_assistant(message) + self.sep2
                        ret += message
                else:
                    ret += "<|assistant|>\n"
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.QWEN_2:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_phi_v0 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="phi",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)
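
# Example (illustrative, not from the original file): with the phi-style template
# above, a single image question renders as
#
#   conv = conv_phi_v0.copy()
#   conv.append_message(conv.roles[0], "<image>\nWhat is shown in this image?")
#   conv.append_message(conv.roles[1], None)
#   conv.get_prompt()
#   # -> "A chat between a curious user and an artificial intelligence assistant. "
#   #    "The assistant gives helpful, detailed, and polite answers to the user's "
#   #    "questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"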


def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float('inf')

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


## added by llava-1.6
def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize the image
    resized_image = image.resize((new_width, new_height))

    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image


## added by llava-1.6
def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches


## added by llava-1.6
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str): A string representation of a list of possible resolutions.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size


## added by llava-1.6
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str): A string representation of a list of possible resolutions.

    Returns:
        torch.Tensor: A tensor containing the processed image patches.
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size['height'])

    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))

    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
                     for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def process_images(images, image_processor, model_cfg):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    elif image_aspect_ratio == "anyres":
        for image in images:
            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
            new_images.append(image)
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
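
# Example (illustrative, not from the original file):
#
#   ids = tokenizer_image_token("USER: <image>\nDescribe the image. ASSISTANT:",
#                               tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
#
# `ids` is the tokenized prompt in which the "<image>" placeholder has been replaced
# by the sentinel id IMAGE_TOKEN_INDEX (-200), which the TinyLLaVA model is expected
# to swap for the projected vision features at generation time.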


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)


def load_image(image_file):
    if image_file.startswith("http") or image_file.startswith("https"):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


def generate(
    prompt: str,
    model: str,
    tokenizer=None,
    image: str = None,
    device: str = None,
    max_new_tokens: int = 1024,
    num_beams=1,
    top_p=None,
    temperature=0.2
):
    if not device:
        if torch.cuda.is_available() and torch.cuda.device_count():
            device = "cuda:0"
            logging.warning(
                'inference device is not set, using cuda:0, %s',
                torch.cuda.get_device_name(0)
            )
        else:
            device = 'cpu'
            logging.warning(
                (
                    'No CUDA device detected, using cpu, '
                    'expect slower speeds.'
                )
            )

    if 'cuda' in device and not torch.cuda.is_available():
        raise ValueError('CUDA device requested but no CUDA device detected.')

    if isinstance(model, str):
        checkpoint_path = model
        # print(f'loading model from {checkpoint_path}...')
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            trust_remote_code=True
        )
        # print('model load over')
    config = model.config
    if tokenizer is None:
        # load the tokenizer shipped alongside the checkpoint
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, use_fast=False,
                                                  model_max_length=config.tokenizer_model_max_length,
                                                  padding_side=config.tokenizer_padding_side)
    image_processor = model.vision_tower._image_processor
    context_len = getattr(config, 'max_sequence_length', 2048)
    model.to(device).eval()

    if image is not None:
        prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
    conv = conv_phi_v0.copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    if image is not None:
        # print('loading image...')
        image = load_image(image)
        # print('load image over')
        image_tensor = process_images(image, image_processor, config).to(model.device, dtype=torch.float16)

    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )
    # Generate
    stime = time.time()
    # stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    # keywords = [stop_str]
    # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    # print('start inference...')
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True if temperature > 0 else False,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            # stopping_criteria=[stopping_criteria],
        )

    # print('inference over')
    generation_time = time.time() - stime
    outputs = tokenizer.batch_decode(
        output_ids, skip_special_tokens=True
    )[0]
    # outputs = outputs.strip()
    # if outputs.endswith(stop_str):
    #     outputs = outputs[: -len(stop_str)]
    outputs = outputs.strip()

    return outputs, generation_time


def tinyllava_elm_generate_parser():
    """Argument Parser"""

    class KwargsParser(argparse.Action):
        """Parser action class to parse kwargs of form key=value"""
        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, dict())
            for val in values:
                if '=' not in val:
                    raise ValueError(
                        (
                            'Argument parsing error, kwargs are expected in'
                            ' the form of key=value.'
                        )
                    )
                kwarg_k, kwarg_v = val.split('=')
                try:
                    converted_v = int(kwarg_v)
                except ValueError:
                    try:
                        converted_v = float(kwarg_v)
                    except ValueError:
                        converted_v = kwarg_v
                getattr(namespace, self.dest)[kwarg_k] = converted_v

    parser = argparse.ArgumentParser('TinyLLaVA-OpenELM Generate Module')
    parser.add_argument(
        '--model',
        dest='model',
        help='Path to the HF-converted model.',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--prompt',
        dest='prompt',
        help='Prompt for LLM call.',
        default='',
        type=str,
    )
    parser.add_argument(
        '--device',
        dest='device',
        help='Device used for inference.',
        type=str,
    )
    parser.add_argument("--image", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=512)
    return parser.parse_args()


if __name__ == '__main__':
    args = tinyllava_elm_generate_parser()

    output_text, generation_time = generate(
        prompt=args.prompt,
        image=args.image,
        model=args.model,
        device=args.device,
        max_new_tokens=args.max_new_tokens,
        num_beams=args.num_beams,
        top_p=args.top_p,
        temperature=args.temperature
    )

    print_txt = (
        f'\r\n{"=" * os.get_terminal_size().columns}\r\n'
        '\033[1m Prompt + Generated Output\033[0m\r\n'
        f'{"-" * os.get_terminal_size().columns}\r\n'
        f'{output_text}\r\n'
        f'{"-" * os.get_terminal_size().columns}\r\n'
        '\r\nGeneration took'
        f'\033[1m\033[92m {round(generation_time, 2)} \033[0m'
        'seconds.\r\n'
    )
    print(print_txt)
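
For reference, a usage sketch (not part of the commit). The checkpoint path and image path below are placeholders; the same call is what the command line entry point performs via python generate_model.py --model <checkpoint> --prompt '...' --image <path>.

# Illustrative sketch only; the checkpoint path and image path are placeholders.
from generate_model import generate

output_text, generation_time = generate(
    prompt="What is shown in this image?",
    image="example.jpg",                   # local path or http(s) URL
    model="path/to/tinyllava-checkpoint",  # HF repo id or local directory
    max_new_tokens=256,
    temperature=0,                         # temperature 0 -> greedy decoding
)
print(output_text)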