tinyllava committed
Commit 66f6d6a
1 Parent(s): b4ca6f9

Upload python file

Files changed (2)
  1. data_preprocess.py +543 -0
  2. modeling_tinyllava_phi.py +55 -0
data_preprocess.py ADDED
@@ -0,0 +1,543 @@
import requests
from PIL import Image
import torch
from io import BytesIO
import base64
import time
from transformers import StoppingCriteria

import math
import ast

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
import dataclasses
from enum import auto, Enum
from typing import List, Tuple


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    TINY_LLAMA = auto()
    QWEN_2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0: message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.TINY_LLAMA:
            sep = "</s>"
            wrap_sys = lambda msg: f"<|system|>\n{msg}\n"
            wrap_user = lambda msg: f"<|user|>\n{msg}\n"
            wrap_assistant = lambda msg: f"<|assistant|>\n{msg}"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        message = wrap_user(message)
                        if i == 0:
                            message = wrap_sys(self.system) + message
                        ret += self.sep + message
                    else:
                        message = wrap_assistant(message) + self.sep2
                        ret += message
                else:
                    ret += "<|assistant|>\n"
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.QWEN_2:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_phi_v0 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="phi",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)


def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float('inf')

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


## added by llava-1.6
def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize the image
    resized_image = image.resize((new_width, new_height))

    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image


## added by llava-1.6
def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches


## added by llava-1.6
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str): A string representation of a list of possible resolutions.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size


## added by llava-1.6
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str): A string representation of a list of possible resolutions.

    Returns:
        torch.Tensor: A tensor containing the processed image patches.
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size['height'])

    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))

    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
                     for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def process_images(images, image_processor, model_cfg):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    elif image_aspect_ratio == "anyres":
        for image in images:
            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
            new_images.append(image)
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)


def load_image(image_file):
    if image_file.startswith("http") or image_file.startswith("https"):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image
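For reference, a minimal sketch of how these utilities compose at inference time. It assumes a TinyLLaVA checkpoint's tokenizer, image processor, and config are already loaded; the variable names `tokenizer`, `image_processor`, and `model_config`, and the image path, are illustrative and not part of this commit.

    from data_preprocess import (
        DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
        conv_phi_v0, load_image, process_images, tokenizer_image_token,
    )

    # Build the phi-style prompt; the first user turn carries the <image> placeholder.
    conv = conv_phi_v0.copy()
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?")
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Tokenize the text around the placeholder; <image> becomes IMAGE_TOKEN_INDEX (-200).
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0)

    # Load and preprocess the image ('pad', 'anyres', or the processor default, per the model config).
    image = load_image("example.jpg")  # local path or http(s) URL (placeholder)
    pixel_values = process_images([image], image_processor, model_config)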
modeling_tinyllava_phi.py CHANGED
@@ -16,6 +16,7 @@ from transformers import CLIPVisionModel, CLIPImageProcessor, SiglipVisionModel,
 from .configuration import TinyLlavaConfig, IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN

 from transformers import AutoConfig, AutoModelForCausalLM, PhiForCausalLM
+from data_preprocess import *

 # from tinyllava.utils.data_utils import get_value_from_kwargs
 CONTROLLER_HEART_BEAT_EXPIRATION = 30
@@ -414,6 +415,60 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
             position_ids = None

         return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
+
+    def chat(
+        self,
+        prompt: str,
+        tokenizer = None,
+        image: str = None,
+        max_new_tokens: int = 512,
+        num_beams = 1,
+        top_p=None,
+        temperature=0
+    ):
+        image_processor = self.vision_tower._image_processor
+
+        if image is not None:
+            prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
+        conv = conv_phi_v0.copy()
+        conv.append_message(conv.roles[0], prompt)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+        if image is not None:
+            image = load_image(image)
+            image_tensor = process_images(image, image_processor, self.config).to(self.device)
+
+        input_ids = (
+            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+            .unsqueeze(0).to(self.device)
+        )
+        # Generate
+        stime = time.time()
+
+        with torch.inference_mode():
+            output_ids = self.generate(
+                input_ids,
+                images=image_tensor,
+                do_sample=True if temperature > 0 else False,
+                temperature=temperature,
+                top_p=top_p,
+                num_beams=num_beams,
+                pad_token_id=tokenizer.pad_token_id,
+                max_new_tokens=max_new_tokens,
+                use_cache=True,
+                # stopping_criteria=[stopping_criteria],
+            )
+
+        # print('inference over')
+        generation_time = time.time() - stime
+        outputs = tokenizer.batch_decode(
+            output_ids, skip_special_tokens=True
+        )[0]
+
+        outputs = outputs.strip()
+
+        return outputs, generation_time
+

 AutoConfig.register("tinyllava", TinyLlavaConfig)
 AutoModelForCausalLM.register(TinyLlavaConfig, TinyLlavaForConditionalGeneration)
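Once the checkpoint is loaded with trust_remote_code=True, the new chat() method wraps prompt construction, image preprocessing, and generation in a single call. A minimal usage sketch follows; the Hub repository id and the image path are placeholders, not something specified by this commit.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    hf_path = "tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B"  # placeholder repository id
    model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(hf_path, use_fast=False)
    if torch.cuda.is_available():
        model = model.cuda()

    # `image` may be a local path or an http(s) URL; it is loaded by load_image()
    # and preprocessed by process_images() from data_preprocess.py.
    output_text, generation_time = model.chat(
        prompt="Describe this image.",
        image="example.jpg",
        tokenizer=tokenizer,
        max_new_tokens=256,
    )
    print(output_text)
    print(f"generation time: {generation_time:.2f}s")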