NGUYEN, Xuan Phi committed on
Commit
357985d
1 Parent(s): 4f071ee
multipurpose_chatbot/engines/image_processing_llava_next.py ADDED
@@ -0,0 +1,619 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for LLaVa-NeXT."""

import math
from typing import Dict, List, Optional, Union

import numpy as np

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
from transformers.image_transforms import (
    convert_to_rgb,
    get_resize_output_image_size,
    pad,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from transformers.utils import TensorType, is_vision_available, logging


logger = logging.get_logger(__name__)


if is_vision_available():
    from PIL import Image


def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:
    """
    Divides an image into patches of a specified size.

    Args:
        image (`np.array`):
            The input image.
        patch_size (`int`):
            The size of each patch.
        input_data_format (`ChannelDimension` or `str`):
            The channel dimension format of the input image.

    Returns:
        list: A list of np.array representing the patches.
    """
    patches = []
    height, width = get_image_size(image, channel_dim=input_data_format)
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            if input_data_format == ChannelDimension.LAST:
                patch = image[i : i + patch_size, j : j + patch_size]
            else:
                patch = image[:, i : i + patch_size, j : j + patch_size]
            patches.append(patch)

    return patches


def expand_to_square(image: np.array, background_color, input_data_format) -> np.array:
    """
    Expands an image to a square by adding a background color.
    """

    height, width = get_image_size(image, channel_dim=input_data_format)
    if width == height:
        return image
    elif width > height:
        result = np.ones((width, width, image.shape[2]), dtype=image.dtype) * background_color
        result[(width - height) // 2 : (width - height) // 2 + height, :] = image
        return result
    else:
        result = np.ones((height, height, image.shape[2]), dtype=image.dtype) * background_color
        result[:, (height - width) // 2 : (height - width) // 2 + width] = image
        return result


def _get_patch_output_size(image, target_resolution, input_data_format):
    original_height, original_width = get_image_size(image, channel_dim=input_data_format)
    target_height, target_width = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    return new_height, new_width


class LlavaNextImageProcessor(BaseImageProcessor):
    r"""
    Constructs a LLaVa-NeXT image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques
    for processing high resolution images as explained in the [LLaVa paper](https://arxiv.org/abs/2310.03744).

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
            Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
            the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
            method.
        image_grid_pinpoints (`List`, *optional*, defaults to `[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]`):
            A list of possible resolutions to use for processing high resolution images. The best resolution is selected
            based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
            `preprocess` method.
        crop_size (`Dict[str, int]`, *optional*, defaults to 224):
            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
            method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        image_grid_pinpoints: List = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"shortest_edge": 224}
        size = get_size_dict(size, default_to_square=False)
        image_grid_pinpoints = (
            image_grid_pinpoints
            if image_grid_pinpoints is not None
            else [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
        )
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")

        self.do_resize = do_resize
        self.size = size
        self.image_grid_pinpoints = image_grid_pinpoints
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.do_convert_rgb = do_convert_rgb

    # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize with CLIP->LLaVa
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
        resized to keep the input aspect ratio.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        default_to_square = True
        if "shortest_edge" in size:
            size = size["shortest_edge"]
            default_to_square = False
        elif "height" in size and "width" in size:
            size = (size["height"], size["width"])
        else:
            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")

        output_size = get_resize_output_image_size(
            image,
            size=size,
            default_to_square=default_to_square,
            input_data_format=input_data_format,
        )

        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def _preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: int = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> Image.Image:
        """
        Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
                the longest edge resized to keep the input aspect ratio.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        images = make_list_of_images(images)

        if do_resize:
            images = [
                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_center_crop:
            images = [
                self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        return images

    def _resize_for_patching(
        self, image: np.array, target_resolution: tuple, resample, input_data_format: ChannelDimension
    ) -> np.array:
        """
        Resizes an image to a target resolution while maintaining aspect ratio.

        Args:
            image (np.array):
                The input image.
            target_resolution (tuple):
                The target resolution (height, width) of the image.
            resample (`PILImageResampling`):
                Resampling filter to use if resizing the image.
            input_data_format (`ChannelDimension` or `str`):
                The channel dimension format of the input image.

        Returns:
            np.array: The resized and padded image.
        """
        new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)

        # Resize the image
        resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)

        return resized_image

    def _pad_for_patching(
        self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
    ) -> np.array:
        """
        Pad an image to a target resolution while maintaining aspect ratio.
        """
        target_height, target_width = target_resolution
        new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)

        paste_x = (target_width - new_width) // 2
        paste_y = (target_height - new_height) // 2

        padded_image = pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))

        return padded_image

    def get_image_patches(
        self,
        image: np.array,
        grid_pinpoints,
        size: tuple,
        patch_size: int,
        resample: PILImageResampling,
        data_format: ChannelDimension,
        input_data_format: ChannelDimension,
    ) -> List[np.array]:
        """
        Process an image with variable resolutions by dividing it into patches.

        Args:
            image (np.array):
                The input image to be processed.
            grid_pinpoints (List):
                A list of possible resolutions.
            size (`tuple`):
                Size to resize the original image to.
            patch_size (`int`):
                Size of the patches to divide the image into.
            resample (`PILImageResampling`):
                Resampling filter to use if resizing the image.
            data_format (`ChannelDimension` or `str`):
                The channel dimension format for the output image.
            input_data_format (`ChannelDimension` or `str`):
                The channel dimension format of the input image.

        Returns:
            List[np.array]: A list of NumPy arrays containing the processed image patches.
        """
        if not isinstance(grid_pinpoints, list):
            raise ValueError("grid_pinpoints must be a list of possible resolutions.")

        possible_resolutions = grid_pinpoints

        image_size = get_image_size(image, channel_dim=input_data_format)
        best_resolution = select_best_resolution(image_size, possible_resolutions)
        resized_image = self._resize_for_patching(
            image, best_resolution, resample=resample, input_data_format=input_data_format
        )
        padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)

        patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format)

        # make sure that all patches are in the input data format
        patches = [
            to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format)
            for patch in patches
        ]

        resized_original_image = resize(
            image,
            size=size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
        )

        image_patches = [resized_original_image] + patches

        return image_patches

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        image_grid_pinpoints: List = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: int = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        concat_images: bool = True,
    ):
        """
        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
                the longest edge resized to keep the input aspect ratio.
            image_grid_pinpoints (`List`, *optional*, defaults to `self.image_grid_pinpoints`):
                A list of possible resolutions to use for processing high resolution images. The best resolution is
                selected based on the original size of the image.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size, param_name="size", default_to_square=False)
        image_grid_pinpoints = image_grid_pinpoints if image_grid_pinpoints is not None else self.image_grid_pinpoints
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        new_images = []
        image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
        for image in images:
            # convert image into a list of patches
            # we intentionally use the same data format as the input data format
            image_patches = self.get_image_patches(
                image,
                image_grid_pinpoints,
                size=(size["shortest_edge"], size["shortest_edge"]),
                patch_size=crop_size["height"],
                resample=resample,
                data_format=input_data_format,
                input_data_format=input_data_format,
            )

            # preprocess patches
            pixel_values = self._preprocess(
                image_patches,
                do_resize=do_resize,
                size=size,
                resample=resample,
                do_center_crop=do_center_crop,
                crop_size=crop_size,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            pixel_values = np.array(pixel_values)
            new_images.append(pixel_values)

        if concat_images:
            # image_num_patches = [len(x) for x in new_images]
            pixel_values = np.concatenate(new_images, axis=0)
            data = {
                "pixel_values": pixel_values,
                "image_sizes": image_sizes,
                # "image_num_patches": image_num_patches,
            }
        else:
            data = {"pixel_values": new_images, "image_sizes": image_sizes}

        return BatchFeature(data=data, tensor_type=return_tensors)
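
A minimal usage sketch of the processor added above (not part of the commit; the input file name and the 336-pixel sizes are illustrative assumptions):

```python
# Minimal usage sketch (assumes Pillow is installed and this module is importable as below).
from PIL import Image

from multipurpose_chatbot.engines.image_processing_llava_next import LlavaNextImageProcessor

processor = LlavaNextImageProcessor(
    size={"shortest_edge": 336},
    crop_size={"height": 336, "width": 336},
)
image = Image.open("example.jpg")  # hypothetical input image
batch = processor.preprocess(image, return_tensors="np")

# pixel_values stacks the resized base image plus all anyres patches:
# shape (num_patches, 3, 336, 336); image_sizes holds one (height, width) per input image.
print(batch["pixel_values"].shape, batch["image_sizes"])
```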
multipurpose_chatbot/engines/modeling_sealava16.py ADDED
@@ -0,0 +1,1022 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Llava-NeXT model."""
16
+
17
+ from dataclasses import dataclass
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ import numpy as np
24
+
25
+ from transformers import PreTrainedModel
26
+ from transformers.activations import ACT2FN
27
+ from transformers.cache_utils import Cache
28
+ from transformers.image_processing_utils import select_best_resolution
29
+ from transformers.modeling_outputs import ModelOutput
30
+ from transformers.configuration_utils import PretrainedConfig
31
+ from transformers.utils import (
32
+ add_start_docstrings,
33
+ add_start_docstrings_to_model_forward,
34
+ logging,
35
+ replace_return_docstrings,
36
+ )
37
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
38
+ # from .configuration_llava_next import LlavaNextConfig
39
+ from transformers.models.auto import CONFIG_MAPPING
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ class LlavaNextConfig(PretrainedConfig):
45
+ r"""
46
+ This is the configuration class to store the configuration of a [`LlavaNextForConditionalGeneration`]. It is used to instantiate an
47
+ Llava-NeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
48
+ with the defaults will yield a similar configuration to that of the [llava-hf/llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)
49
+ model.
50
+
51
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
52
+ documentation from [`PretrainedConfig`] for more information.
53
+
54
+ Args:
55
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
56
+ The config object or dictionary of the vision backbone.
57
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
58
+ The config object or dictionary of the text backbone.
59
+ ignore_index (`int`, *optional*, defaults to -100):
60
+ The ignore index for the loss function.
61
+ image_token_index (`int`, *optional*, defaults to 32000):
62
+ The image token index to encode the image prompt.
63
+ projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
64
+ The activation function used by the multimodal projector.
65
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
66
+ The feature selection strategy used to select the vision feature from the vision backbone.
67
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
68
+ If `"full"`, the full vision features are used.
69
+ vision_feature_layer (`int`, *optional*, defaults to -2):
70
+ The index of the layer to select the vision feature.
71
+ image_grid_pinpoints (`List`, *optional*, defaults to `[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]`):
72
+ A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
73
+ of the form `(height, width)`.
74
+
75
+ Example:
76
+
77
+ ```python
78
+ >>> from transformers import LlavaNextForConditionalGeneration, LlavaNextConfig, CLIPVisionConfig, LlamaConfig
79
+
80
+ >>> # Initializing a CLIP-vision config
81
+ >>> vision_config = CLIPVisionConfig()
82
+
83
+ >>> # Initializing a Llama config
84
+ >>> text_config = LlamaConfig()
85
+
86
+ >>> # Initializing a Llava-Next llava-hf/llava-v1.6-mistral-7b-hf style configuration
87
+ >>> configuration = LlavaNextConfig(vision_config, text_config)
88
+
89
+ >>> # Initializing a model from the llava-hf/llava-v1.6-mistral-7b-hf style configuration
90
+ >>> model = LlavaNextForConditionalGeneration(configuration)
91
+
92
+ >>> # Accessing the model configuration
93
+ >>> configuration = model.config
94
+ ```"""
95
+
96
+ model_type = "llava_next"
97
+ is_composition = False
98
+
99
+ def __init__(
100
+ self,
101
+ vision_config=None,
102
+ text_config=None,
103
+ ignore_index=-100,
104
+ image_token_index=32000,
105
+ projector_hidden_act="gelu",
106
+ vision_feature_select_strategy="default",
107
+ vision_feature_layer=-2,
108
+ image_grid_pinpoints=None,
109
+ **kwargs,
110
+ ):
111
+ self.ignore_index = ignore_index
112
+ self.image_token_index = image_token_index
113
+ self.projector_hidden_act = projector_hidden_act
114
+
115
+ if vision_feature_select_strategy not in ["default", "full"]:
116
+ raise ValueError(
117
+ "vision_feature_select_strategy should be one of 'default', 'full'."
118
+ f"Got: {vision_feature_select_strategy}"
119
+ )
120
+
121
+ self.vision_feature_select_strategy = vision_feature_select_strategy
122
+ self.vision_feature_layer = vision_feature_layer
123
+ image_grid_pinpoints = (
124
+ image_grid_pinpoints
125
+ if image_grid_pinpoints is not None
126
+ else [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
127
+ )
128
+ self.image_grid_pinpoints = image_grid_pinpoints
129
+
130
+ if isinstance(vision_config, dict):
131
+ vision_config["model_type"] = (
132
+ vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
133
+ )
134
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
135
+ elif vision_config is None:
136
+ vision_config = CONFIG_MAPPING["clip_vision_model"](
137
+ intermediate_size=4096,
138
+ hidden_size=1024,
139
+ patch_size=14,
140
+ image_size=336,
141
+ num_hidden_layers=24,
142
+ num_attention_heads=16,
143
+ vocab_size=32000,
144
+ projection_dim=768,
145
+ )
146
+
147
+ self.vision_config = vision_config
148
+
149
+ if isinstance(text_config, dict):
150
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
151
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
152
+ elif text_config is None:
153
+ text_config = CONFIG_MAPPING["llama"]()
154
+
155
+ self.text_config = text_config
156
+
157
+ super().__init__(**kwargs)
158
+
159
+
160
+
161
+
162
+
163
+
164
+ _CONFIG_FOR_DOC = "LlavaNextConfig"
165
+
166
+ LLAVA_NEXT_PRETRAINED_MODEL_ARCHIVE_LIST = [
167
+ "llava-hf/llava-v1.6-mistral-7b-hf",
168
+ # See all LLaVA-NeXT models at https://huggingface.co/models?filter=llava_next
169
+ ]
170
+
171
+
172
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
173
+ """
174
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
175
+
176
+ Args:
177
+ image_size (`tuple`):
178
+ The size of the input image in the format (width, height).
179
+ grid_pinpoints (`List`):
180
+ A list containing possible resolutions. Each item in the list should be a tuple or list
181
+ of the form `(height, width)`.
182
+ patch_size (`int`):
183
+ The size of each image patch.
184
+
185
+ Returns:
186
+ tuple: The shape of the image patch grid in the format (width, height).
187
+ """
188
+ if not isinstance(grid_pinpoints, list):
189
+ raise ValueError("grid_pinpoints should be a list of tuples or lists")
190
+
191
+ # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
192
+ if not isinstance(image_size, (list, tuple)):
193
+ assert isinstance(image_size, (torch.Tensor, np.ndarray)), f'image_size invalid type: {type(image_size)} | {image_size}'
194
+ image_size = image_size.tolist()
195
+
196
+ height, width = select_best_resolution(image_size, grid_pinpoints)
197
+ return height // patch_size, width // patch_size
198
+
199
+
200
+ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
201
+ """
202
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
203
+
204
+ Args:
205
+ image_size (`tuple`):
206
+ The size of the input image in the format (height, width). ?
207
+ grid_pinpoints (`List`):
208
+ A list containing possible resolutions. Each item in the list should be a tuple or list
209
+ of the form `(height, width)`.
210
+ patch_size (`int`):
211
+ The size of each image patch.
212
+
213
+ Returns:
214
+ tuple: The shape of the image patch grid in the format (height, width). ?
215
+ """
216
+ if not isinstance(grid_pinpoints, list):
217
+ raise ValueError("grid_pinpoints should be a list of tuples or lists")
218
+
219
+ # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
220
+ if not isinstance(image_size, (list, tuple)):
221
+ assert isinstance(image_size, (torch.Tensor, np.ndarray)), f'image_size invalid type: {type(image_size)} | {image_size}'
222
+ image_size = image_size.tolist()
223
+
224
+ best_resolution = select_best_resolution(image_size, grid_pinpoints)
225
+ height, width = best_resolution
226
+ num_patches = 0
227
+ for i in range(0, height, patch_size):
228
+ for j in range(0, width, patch_size):
229
+ num_patches += 1
230
+ # add the base patch
231
+ num_patches += 1
232
+ return num_patches
233
+
234
+
235
+ def unpad_image(tensor, original_size):
236
+ """
237
+ Unpads a PyTorch tensor of a padded and resized image.
238
+
239
+ Args:
240
+ tensor (`torch.Tensor`):
241
+ The image tensor, assumed to be of shape (num_channels, height, width).
242
+ original_size (`tuple`):
243
+ The original size of the image (height, width).
244
+
245
+ Returns:
246
+ `torch.Tensor`: The unpadded image tensor.
247
+ """
248
+ original_height, original_width = original_size
249
+ current_height, current_width = tensor.shape[1:]
250
+
251
+ original_aspect_ratio = original_width / original_height
252
+ current_aspect_ratio = current_width / current_height
253
+
254
+ if original_aspect_ratio > current_aspect_ratio:
255
+ scale_factor = current_width / original_width
256
+ new_height = int(original_height * scale_factor)
257
+ padding = (current_height - new_height) // 2
258
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
259
+ else:
260
+ scale_factor = current_height / original_height
261
+ new_width = int(original_width * scale_factor)
262
+ padding = (current_width - new_width) // 2
263
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
264
+
265
+ return unpadded_tensor
266
+
267
+
268
+ @dataclass
269
+ # Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->LlavaNext
270
+ class LlavaNextCausalLMOutputWithPast(ModelOutput):
271
+ """
272
+ Base class for LlavaNext causal language model (or autoregressive) outputs.
273
+
274
+ Args:
275
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
276
+ Language modeling loss (for next-token prediction).
277
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
278
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
279
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
280
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
281
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
282
+
283
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
284
+ `past_key_values` input) to speed up sequential decoding.
285
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
286
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
287
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
288
+
289
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
290
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
291
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
292
+ sequence_length)`.
293
+
294
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
295
+ heads.
296
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
297
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
298
+ sequence_length, hidden_size)`.
299
+
300
+ image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
301
+ """
302
+
303
+ loss: Optional[torch.FloatTensor] = None
304
+ logits: torch.FloatTensor = None
305
+ past_key_values: Optional[List[torch.FloatTensor]] = None
306
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
307
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
308
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
309
+
310
+
311
+ # Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
312
+ class LlavaNextMultiModalProjector(nn.Module):
313
+ def __init__(self, config: LlavaNextConfig):
314
+ super().__init__()
315
+
316
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
317
+ self.act = ACT2FN[config.projector_hidden_act]
318
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
319
+
320
+ def forward(self, image_features):
321
+ hidden_states = self.linear_1(image_features)
322
+ hidden_states = self.act(hidden_states)
323
+ hidden_states = self.linear_2(hidden_states)
324
+ return hidden_states
325
+
326
+
327
+ LLAVA_NEXT_START_DOCSTRING = r"""
328
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
329
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
330
+ etc.)
331
+
332
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
333
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
334
+ and behavior.
335
+
336
+ Parameters:
337
+ config ([`LlavaNextConfig`] or [`LlavaNextVisionConfig`]):
338
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
339
+ load the weights associated with the model, only the configuration. Check out the
340
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
341
+ """
342
+
343
+
344
+ @add_start_docstrings(
345
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
346
+ LLAVA_NEXT_START_DOCSTRING,
347
+ )
348
+ # Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->LlavaNext,llava->llava_next
349
+ class LlavaNextPreTrainedModel(PreTrainedModel):
350
+ config_class = LlavaNextConfig
351
+ base_model_prefix = "model"
352
+ supports_gradient_checkpointing = True
353
+ _no_split_modules = ["LlavaNextVisionAttention"]
354
+ _skip_keys_device_placement = "past_key_values"
355
+ _supports_flash_attn_2 = True
356
+
357
+ def _init_weights(self, module):
358
+ # important: this ported version of LlavaNext isn't meant for training from scratch - only
359
+ # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
360
+ # https://github.com/haotian-liu/LLaVA/tree/main/llava_next should serve for that purpose
361
+ std = (
362
+ self.config.initializer_range
363
+ if hasattr(self.config, "initializer_range")
364
+ else self.config.text_config.initializer_range
365
+ )
366
+
367
+ if hasattr(module, "class_embedding"):
368
+ module.class_embedding.data.normal_(mean=0.0, std=std)
369
+
370
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
371
+ module.weight.data.normal_(mean=0.0, std=std)
372
+ if module.bias is not None:
373
+ module.bias.data.zero_()
374
+ elif isinstance(module, nn.Embedding):
375
+ module.weight.data.normal_(mean=0.0, std=std)
376
+ if module.padding_idx is not None:
377
+ module.weight.data[module.padding_idx].zero_()
378
+
379
+ @property
380
+ def _supports_sdpa(self):
381
+ """
382
+ Retrieve language_model's attribute to check whether the model supports
383
+ SDPA or not.
384
+ """
385
+ return self.language_model._supports_sdpa
386
+
387
+
388
+ LLAVA_NEXT_INPUTS_DOCSTRING = r"""
389
+ Args:
390
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
391
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
392
+ it.
393
+
394
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
395
+ [`PreTrainedTokenizer.__call__`] for details.
396
+
397
+ [What are input IDs?](../glossary#input-ids)
398
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
399
+ The tensors corresponding to the input images. Pixel values can be obtained using
400
+ [`AutoImageProcessor`]. See [`LlavaNextImageProcessor.__call__`] for details. [`LlavaProcessor`] uses
401
+ [`LlavaNextImageProcessor`] for processing images.
402
+ image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
403
+ The sizes of the images in the batch, being (height, width) for each image.
404
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
405
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
406
+
407
+ - 1 for tokens that are **not masked**,
408
+ - 0 for tokens that are **masked**.
409
+
410
+ [What are attention masks?](../glossary#attention-mask)
411
+
412
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
413
+ [`PreTrainedTokenizer.__call__`] for details.
414
+
415
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
416
+ `past_key_values`).
417
+
418
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
419
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
420
+ information on the default strategy.
421
+
422
+ - 1 indicates the head is **not masked**,
423
+ - 0 indicates the head is **masked**.
424
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
425
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
426
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
427
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
428
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
429
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
430
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
431
+
432
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
433
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
434
+
435
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
436
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
437
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
438
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
439
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
440
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
441
+ model's internal embedding lookup matrix.
442
+ vision_feature_layer (`int`, *optional*, defaults to -2):
443
+ The index of the layer to select the vision feature.
444
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
445
+ The feature selection strategy used to select the vision feature from the vision backbone.
446
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
447
+ If `"full"`, the full vision features are used.
448
+ use_cache (`bool`, *optional*):
449
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
450
+ `past_key_values`).
451
+ output_attentions (`bool`, *optional*):
452
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
453
+ tensors for more detail.
454
+ output_hidden_states (`bool`, *optional*):
455
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
456
+ more detail.
457
+ return_dict (`bool`, *optional*):
458
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
459
+ """
460
+
461
+
462
+ @add_start_docstrings(
463
+ """The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
464
+ LLAVA_NEXT_START_DOCSTRING,
465
+ )
466
+ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
467
+ def __init__(self, config: LlavaNextConfig):
468
+ super().__init__(config)
469
+ self.vision_tower = AutoModel.from_config(config.vision_config)
470
+
471
+ self.multi_modal_projector = LlavaNextMultiModalProjector(config)
472
+
473
+ self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size, dtype=self.dtype))
474
+
475
+ self.vocab_size = config.text_config.vocab_size
476
+ self.language_model = AutoModelForCausalLM.from_config(
477
+ config.text_config, attn_implementation=config._attn_implementation
478
+ )
479
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
480
+ self.post_init()
481
+
482
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
483
+ def get_input_embeddings(self):
484
+ return self.language_model.get_input_embeddings()
485
+
486
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
487
+ def set_input_embeddings(self, value):
488
+ self.language_model.set_input_embeddings(value)
489
+
490
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
491
+ def get_output_embeddings(self):
492
+ return self.language_model.get_output_embeddings()
493
+
494
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
495
+ def set_output_embeddings(self, new_embeddings):
496
+ self.language_model.set_output_embeddings(new_embeddings)
497
+
498
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
499
+ def set_decoder(self, decoder):
500
+ self.language_model.set_decoder(decoder)
501
+
502
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
503
+ def get_decoder(self):
504
+ return self.language_model.get_decoder()
505
+
506
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
507
+ def tie_weights(self):
508
+ return self.language_model.tie_weights()
509
+
510
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
511
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
512
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
513
+ # update vocab size
514
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
515
+ self.vocab_size = model_embeds.num_embeddings
516
+ return model_embeds
517
+
518
+ def _merge_input_ids_with_image_features(
519
+ self, image_features, feature_lens, inputs_embeds, input_ids, attention_mask, position_ids=None,
520
+ labels=None, image_token_index=None,
521
+ ignore_index=-100,
522
+ padding_side: Optional[str] = "left",
523
+ ):
524
+ """
525
+ Args:
526
+ input_ids: [batch_size, tlen]
527
+ input_embeds: [batch_size, tlen, dt]
528
+ image_features: [all_feat_lens, di]
529
+ feature_lens: [num_images],
530
+ num_images=number of image in the batch
531
+ each value is the length of embedding featres of each image
532
+ Note: sum(feature_lens) == all_feat_lens
533
+ labels: None or [batch_size, tlen] --> must extend labels to input_ids,
534
+ padding_side: `left` or `right`,
535
+ must specify for generation because we cannot tell that from input_ids
536
+ see below
537
+ Returns:
538
+ final_embedding, final_attention_mask, position_ids, final_labels
539
+
540
+ Explanation:
541
+ each image has variable length embeddings, with length specified by feature_lens
542
+ image_features is concatenation of all visual embed vectors
543
+ task: fill each <image> with the correct number of visual embeddings
544
+ Example:
545
+ X (5 patches), Y (3 patches), Z (8)
546
+ X, Y is on the same sequence (in-context learning)
547
+ if right padding
548
+ input_ids: [
549
+ a b c d e f X g h i j k Y l m
550
+ o p q r Z s t u v _ _ _ _ _ _
551
+ ]
552
+ input_ids should be: [
553
+ a b c d e f X X X X X g h i j k Y Y Y l m
554
+ o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
555
+ ]
556
+ labels should be: [
557
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
558
+ o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
559
+ ]
560
+ elif left padding
561
+ input_ids: [
562
+ a b c d e f X g h i j k Y l m
563
+ _ _ _ _ _ _ o p q r Z s t u v
564
+ ]
565
+ input_ids should be: [
566
+ a b c d e f X X X X X g h i j k Y Y Y l m
567
+ _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
568
+ ]
569
+ labels should be: [
570
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
571
+ _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
572
+ ]
573
+ Edge cases:
574
+ * If the text tokens are the same but the image token counts differ, we cannot infer left or right padding
575
+ ```python
576
+ cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
577
+ chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw)
578
+ prompts = [
579
+ "[INST] <image>\nWhat is shown in this image? [/INST]",
580
+ "[INST] <image>\nWhat is shown in this image? [/INST]",
581
+ ]
582
+ inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda")
583
+ chart_img has 2634 tokens, while cat_img has 2340 tokens
584
+ ```
585
+
586
+ input_ids: [
587
+ a b c d X g h
588
+ i j Y k l m n
589
+ ]
590
+ where X is 3 tokens while Y is 5; this means that after merging,
591
+ if left-padding (batched generation)
592
+ input_ids should be: [
593
+ _ _ a b c d X X X g h
594
+ i j Y Y Y Y Y k l m n
595
+ ]
596
+ elif (right padding) (training)
597
+ input_ids should be: [
598
+ a b c d X X X g h _ _
599
+ i j Y Y Y Y Y k l m n
600
+ ]
601
+
602
+
603
+ """
604
+ image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
605
+ ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
606
+
607
+ with torch.no_grad():
608
+ # ! in llava 1.6, number of patches is variable
609
+ num_images = feature_lens.size(0)
610
+ num_image_features, embed_dim = image_features.shape
611
+ assert feature_lens.sum() == num_image_features, f'{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}'
612
+ batch_size, sequence_length = input_ids.shape
613
+ _left_padding = torch.any(attention_mask[:, 0] == 0)
614
+ _right_padding = torch.any(attention_mask[:, -1] == 0)
615
+
616
+ if _left_padding and not _right_padding:
617
+ left_padding = True
618
+ elif not _left_padding and _right_padding:
619
+ left_padding = False
620
+ elif not _left_padding and not _right_padding:
621
+ # both sides are all ones, so we cannot tell; fall back to padding_side
622
+ left_padding = padding_side == "left"
623
+ else:
624
+ # invalid attention_mask
625
+ raise ValueError(f'both sides of attention_mask have zeros, which is invalid: {attention_mask}')
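+ # Hedged illustration of the padding-side inference above (toy masks, not real data):
+ #   attention_mask = [[0, 1, 1], [1, 1, 1]] -> a zero in the first column -> left padding
+ #   attention_mask = [[1, 1, 0], [1, 1, 1]] -> a zero in the last column  -> right padding
+ #   attention_mask = [[1, 1, 1], [1, 1, 1]] -> ambiguous, fall back to `padding_side`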
626
+
627
+ # The padding side resolved above determines how the merged sequence is laid out below
628
+ # 1. Create a mask to know where special image tokens are
629
+ special_image_token_mask = input_ids == image_token_index
630
+ # special_image_token_mask: [bsz, seqlen]
631
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
632
+ # num_special_image_tokens: [bsz]
633
+ # Reserve for padding of num_images
634
+ total_num_special_image_tokens = torch.sum(special_image_token_mask)
635
+ assert total_num_special_image_tokens == num_images, (
636
+ f'{total_num_special_image_tokens=} != {num_images=} | {image_features.shape} {input_ids}'
637
+ )
638
+ # Compute the maximum embed dimension
639
+ # compute, per sample, the total image feature length and the resulting merged sequence length
640
+ feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
641
+ feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=feature_lens.device)
642
+ embed_sequence_lengths = (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
643
+ max_embed_dim = embed_sequence_lengths.max()
644
+
645
+ batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
646
+ # 2. Compute the positions where text should be written
647
+ # Calculate new positions for text tokens in merged image-text sequence.
648
+ # `special_image_token_mask` identifies image tokens. Each image token is expanded to `feature_len` embeddings, shifting subsequent tokens by `feature_len - 1` positions.
649
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
650
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
651
+ # ! instead of special_image_token_mask * (num_image_patches - 1)
652
+ # special_image_token_mask * (num_feature_len - 1)
653
+ special_image_len_mask = special_image_token_mask.clone().long()
654
+ special_image_len_mask[special_image_len_mask == 1] = feature_lens - 1
655
+ new_token_positions = torch.cumsum((special_image_len_mask + 1), -1) - 1
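+ # Cumsum sketch with made-up lengths: one <image> token that expands into 5 features gives
+ #   special_image_len_mask   = [[0, 0, 4, 0, 0]]
+ #   cumsum(mask + 1, -1) - 1 = [[0, 1, 6, 7, 8]]
+ # i.e. the image token's last feature lands at position 6 and the two trailing text tokens are
+ # pushed right by 4 positions; the left-padding branch below then right-aligns shorter rows.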
656
+ if left_padding:
657
+ # shift token positions to the right so that every sequence ends at the same index
658
+ new_token_positions += (new_token_positions[:, -1].max() - new_token_positions[:, -1:])
659
+
660
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
661
+
662
+ # 3. Create the full embedding, already padded to the maximum position
663
+ final_embedding = torch.zeros(
664
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
665
+ )
666
+ final_attention_mask = torch.zeros(
667
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
668
+ )
669
+ final_labels = None
670
+ if labels is not None:
671
+ final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
672
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
673
+ # set the corresponding tensors into their correct target device.
674
+ target_device = inputs_embeds.device
675
+ batch_indices, non_image_indices, text_to_overwrite = (
676
+ batch_indices.to(target_device),
677
+ non_image_indices.to(target_device),
678
+ text_to_overwrite.to(target_device),
679
+ )
680
+ attention_mask = attention_mask.to(target_device)
681
+
682
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
683
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
684
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
685
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
686
+ if labels is not None:
687
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
688
+
689
+ # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
690
+ with torch.no_grad():
691
+ image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
692
+ if left_padding:
693
+ # exclude padding on the left
694
+ val = (max_embed_dim - torch.arange(max_embed_dim).unsqueeze(0).to(target_device).expand(batch_size, max_embed_dim)) <= embed_sequence_lengths[:, None].to(target_device)
695
+ image_to_overwrite &= val
696
+ else:
697
+ # exclude padding on the right
698
+ val = torch.arange(max_embed_dim).unsqueeze(0).to(target_device).expand(batch_size, max_embed_dim) < embed_sequence_lengths[:, None].to(target_device)
699
+ image_to_overwrite &= val
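+ # Toy check of the right-padding mask above: with max_embed_dim = 6 and embed_sequence_lengths = [4],
+ #   arange(6) < 4 -> [True, True, True, True, False, False]
+ # so only the first 4 slots (the real sequence) may receive image features; the padded tail stays zero.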
700
+
701
+ if image_to_overwrite.sum() != num_image_features:
702
+ raise ValueError(
703
+ f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
704
+ f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
705
+ f" the number of image given to the model is {num_images}. "
706
+ f"This prevents correct indexing and breaks batch generation."
707
+ )
708
+ final_embedding[image_to_overwrite] = image_features.to(target_device)
709
+ final_attention_mask |= image_to_overwrite
710
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
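+ # e.g. final_attention_mask = [1, 1, 1, 0, 0] -> cumsum - 1 = [0, 1, 2, 2, 2], and filling the
+ # padded tail with 1 gives position_ids = [0, 1, 2, 1, 1] (toy example; any value works for pads).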
711
+
712
+ if not left_padding:
713
+ # Sanity check: with right padding, every mask row must be all ones up to seq_len and all zeros after
714
+ seq_lens = final_attention_mask.sum(-1)
715
+ for i, (mask, seq_len) in enumerate(zip(final_attention_mask, seq_lens)):
716
+ # seq_len = mask.sum(-1)
717
+ assert torch.all(mask[:seq_len] == 1), f'final 1 mask[{i}]: {seq_len=} {final_attention_mask.size()=} {final_attention_mask.tolist()=} \n{text_to_overwrite.tolist()=}'
718
+ assert torch.all(mask[seq_len:] == 0), f'final 0 mask[{i}]: {seq_len=} {final_attention_mask.size()=} {final_attention_mask.tolist()=}'
719
+
720
+ return final_embedding, final_attention_mask, position_ids, final_labels
721
+
722
+ def pack_image_features(self, image_features, image_sizes, image_newline=None):
723
+ """
724
+ List of image features
725
+ image_features: list (size num_images) [patches, feat, dim]
726
+ Returns:
727
+ image_features: [all_feat_len, embed_dim]
728
+ feature_lens: [num_images] # feature length of each image
729
+ """
730
+ new_image_features = []
731
+ feature_lens = []
732
+ for image_idx, image_feature in enumerate(image_features):
733
+ if image_feature.shape[0] > 1:
734
+ base_image_feature = image_feature[0]
735
+ image_feature = image_feature[1:]
736
+ height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
737
+ if height * width != base_image_feature.shape[0]:
738
+ raise ValueError("The number of patches is not consistent with the image size.")
739
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(
740
+ image_sizes[image_idx],
741
+ self.config.image_grid_pinpoints,
742
+ self.config.vision_config.image_size,
743
+ )
744
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
745
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
746
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
747
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
748
+ if image_newline is not None:
749
+ image_feature = torch.cat(
750
+ (
751
+ image_feature,
752
+ image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature),
753
+ ),
754
+ dim=-1,
755
+ )
756
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
757
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
758
+ else:
759
+ image_feature = image_feature[0]
760
+ if image_newline is not None:
761
+ image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
762
+ new_image_features.append(image_feature)
763
+ feature_lens.append(image_feature.size(0))
764
+ image_features = torch.cat(new_image_features, dim=0)
765
+ feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
766
+ return image_features, feature_lens
767
+
768
+ @add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
769
+ @replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
770
+ def forward(
771
+ self,
772
+ input_ids: torch.LongTensor = None,
773
+ pixel_values: torch.FloatTensor = None,
774
+ image_sizes: Optional[torch.LongTensor] = None,
775
+ attention_mask: Optional[torch.Tensor] = None,
776
+ position_ids: Optional[torch.LongTensor] = None,
777
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
778
+ inputs_embeds: Optional[torch.FloatTensor] = None,
779
+ vision_feature_layer: Optional[int] = None,
780
+ vision_feature_select_strategy: Optional[str] = None,
781
+ labels: Optional[torch.LongTensor] = None,
782
+ use_cache: Optional[bool] = None,
783
+ output_attentions: Optional[bool] = None,
784
+ output_hidden_states: Optional[bool] = None,
785
+ return_dict: Optional[bool] = None,
786
+ padding_side: Optional[str] = "left",
787
+ ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
788
+ r"""
789
+ Args:
790
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
791
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
792
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
793
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
794
+
795
+ Returns:
796
+
797
+ Example:
798
+
799
+ ```python
800
+ >>> from PIL import Image
801
+ >>> import requests
802
+ >>> from transformers import AutoProcessor, LlavaNextForConditionalGeneration
803
+
804
+ >>> model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
805
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
806
+
807
+ >>> prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
808
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
809
+ >>> image = Image.open(requests.get(url, stream=True).raw)
810
+
811
+ >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
812
+
813
+ >>> # Generate
814
+ >>> generate_ids = model.generate(**inputs, max_length=30)
815
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
816
+ "[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot (...)"
817
+ ```"""
818
+
819
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
820
+ output_hidden_states = (
821
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
822
+ )
823
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
824
+ vision_feature_layer = (
825
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
826
+ )
827
+ vision_feature_select_strategy = (
828
+ vision_feature_select_strategy
829
+ if vision_feature_select_strategy is not None
830
+ else self.config.vision_feature_select_strategy
831
+ )
832
+
833
+ if inputs_embeds is None:
834
+ # 1. Extract the input embeddings
835
+ # In case image_token_index is not in the embeddings (extra token, but the embedding matrix does not have it)
836
+ for_inputs_embeds_ids = input_ids.clone()
837
+ for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
838
+ inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
839
+
840
+ # 2. Merge text and images
841
+ if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
842
+ # ! infer image_num_patches from image_sizes
843
+ image_num_patches = [
844
+ image_size_to_num_patches(
845
+ image_size=imsize,
846
+ grid_pinpoints=self.config.image_grid_pinpoints,
847
+ patch_size=self.config.vision_config.image_size
848
+ )
849
+ for imsize in image_sizes
850
+ ]
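+ # e.g. with a hypothetical config where image_grid_pinpoints contains [672, 336] and
+ # vision_config.image_size = 336, a 672x336 image maps to a 2x1 grid plus the base patch,
+ # i.e. 3 patches for that image (assuming the standard LLaVA-NeXT counting in
+ # `image_size_to_num_patches`, which includes the base patch).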
851
+ image_features = self.vision_tower(pixel_values, output_hidden_states=True)
852
+ selected_image_feature = image_features.hidden_states[vision_feature_layer]
853
+
854
+ if vision_feature_select_strategy == "default":
855
+ selected_image_feature = selected_image_feature[:, 1:]
856
+ elif vision_feature_select_strategy == "full":
857
+ selected_image_feature = selected_image_feature
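+ # "default" drops the leading CLS embedding from each patch's hidden states; "full" keeps all positions.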
858
+
859
+ image_features = self.multi_modal_projector(selected_image_feature)
860
+
861
+ image_features = torch.split(image_features, image_num_patches, dim=0)
862
+
863
+ # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
864
+ height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
865
+
866
+ image_features, feature_lens = self.pack_image_features(
867
+ image_features, image_sizes,
868
+ image_newline=self.image_newline,
869
+ )
870
+
871
+ inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features(
872
+ image_features, feature_lens, inputs_embeds, input_ids, attention_mask, position_ids,
873
+ labels=labels,
874
+ padding_side=padding_side,
875
+ )
876
+
877
+ # pixel_values is not None but is empty ---> text-only case
878
+ elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
879
+ # there are no images
880
+ pass
881
+
882
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
883
+ # generation with cache
884
+ elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
885
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
886
+ # that are set to 0
887
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
888
+
889
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
890
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
891
+
892
+ # Get the target length
893
+ target_seqlen = first_layer_past_key_value.shape[-1] + 1
894
+
895
+ extended_attention_mask = torch.ones(
896
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
897
+ dtype=attention_mask.dtype,
898
+ device=attention_mask.device,
899
+ )
900
+
901
+ # Filter out only the tokens that can be un-attended; this can happen
902
+ # if one uses Llava + Fused modules where the cache on the
903
+ # first iteration is already big enough, or if one passes custom cache
904
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
905
+ new_batch_index = batch_index[valid_indices]
906
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
907
+
908
+ # Zero-out the places where we don't need to attend
909
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
910
+
911
+ # !(nxphi47) must ensure left-padding
912
+ # attention_mask is the new in-coming mask, while extended_attention_mask is the previous one
913
+ assert padding_side == "left", f"{padding_side=} is invalid for batched generation mode"
914
+ attention_mask = torch.cat((extended_attention_mask, attention_mask), dim=1)
915
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
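+ # Cache-step sketch (toy sizes, not from a real run): if the merged prompt left a 2000-entry KV
+ # cache while the incoming attention_mask only covers the original text tokens, the ones block
+ # built above pads the mask up to target_seqlen (cache length + 1) so the single new token can
+ # attend to the whole cache, except positions whose cached key rows were zeroed-out padding.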
916
+
917
+
918
+ outputs = self.language_model(
919
+ attention_mask=attention_mask,
920
+ position_ids=position_ids,
921
+ past_key_values=past_key_values,
922
+ inputs_embeds=inputs_embeds,
923
+ use_cache=use_cache,
924
+ output_attentions=output_attentions,
925
+ output_hidden_states=output_hidden_states,
926
+ return_dict=return_dict,
927
+ )
928
+
929
+ logits = outputs[0]
930
+
931
+ loss = None
932
+ if labels is not None:
933
+ # Shift so that tokens < n predict n
934
+ if attention_mask is not None:
935
+ shift_attention_mask = attention_mask[..., 1:]
936
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
937
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
938
+ else:
939
+ shift_logits = logits[..., :-1, :].contiguous()
940
+ shift_labels = labels[..., 1:].contiguous()
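+ # Shift convention: logits at position t are scored against labels at position t + 1; when an
+ # attention mask is available, positions whose next token is padding are dropped from the loss.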
941
+ # Flatten the tokens
942
+ loss_fct = nn.CrossEntropyLoss()
943
+ loss = loss_fct(
944
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
945
+ )
946
+
947
+ if not return_dict:
948
+ output = (logits,) + outputs[1:]
949
+ return (loss,) + output if loss is not None else output
950
+
951
+ return LlavaNextCausalLMOutputWithPast(
952
+ loss=loss,
953
+ logits=logits,
954
+ past_key_values=outputs.past_key_values,
955
+ hidden_states=outputs.hidden_states,
956
+ attentions=outputs.attentions,
957
+ )
958
+
959
+ def prepare_inputs_for_generation(
960
+ self,
961
+ input_ids,
962
+ past_key_values=None,
963
+ inputs_embeds=None,
964
+ pixel_values=None,
965
+ image_sizes=None,
966
+ attention_mask=None,
967
+ **kwargs,
968
+ ):
969
+ if past_key_values is not None:
970
+ if isinstance(past_key_values, Cache):
971
+ cache_length = past_key_values.get_seq_length()
972
+ past_length = past_key_values.seen_tokens
973
+ else:
974
+ cache_length = past_length = past_key_values[0][0].shape[2]
975
+
976
+ # Keep only the unprocessed tokens:
977
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
978
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
979
+ # input)
980
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
981
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
982
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
983
+ # input_ids based on the past_length.
984
+ elif past_length < input_ids.shape[1]:
985
+ input_ids = input_ids[:, past_length:]
986
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
987
+ elif self.config.image_token_index in input_ids:
988
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
989
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
990
+ # older attention values, as their corresponding values are not part of the input.
991
+ if cache_length < past_length and attention_mask is not None:
992
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
993
+
994
+ position_ids = kwargs.get("position_ids", None)
995
+ if attention_mask is not None and position_ids is None:
996
+ # create position_ids on the fly for batch generation
997
+ position_ids = attention_mask.long().cumsum(-1) - 1
998
+ position_ids.masked_fill_(attention_mask == 0, 1)
999
+ if past_key_values:
1000
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1001
+
1002
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1003
+ if inputs_embeds is not None and past_key_values is None:
1004
+ model_inputs = {"inputs_embeds": inputs_embeds}
1005
+ else:
1006
+ model_inputs = {"input_ids": input_ids}
1007
+
1008
+ model_inputs.update(
1009
+ {
1010
+ "position_ids": position_ids,
1011
+ "past_key_values": past_key_values,
1012
+ "use_cache": kwargs.get("use_cache"),
1013
+ "attention_mask": attention_mask,
1014
+ "pixel_values": pixel_values,
1015
+ "image_sizes": image_sizes,
1016
+ }
1017
+ )
1018
+ return model_inputs
1019
+
1020
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._reorder_cache
1021
+ def _reorder_cache(self, *args, **kwargs):
1022
+ return self.language_model._reorder_cache(*args, **kwargs)
multipurpose_chatbot/engines/processing_llava_next.py ADDED
@@ -0,0 +1,135 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for LLaVa-NeXT.
17
+ """
18
+
19
+
20
+ from typing import List, Optional, Union
21
+
22
+ from transformers.feature_extraction_utils import BatchFeature
23
+ from transformers.image_utils import ImageInput
24
+ from transformers.processing_utils import ProcessorMixin
25
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
26
+ from transformers.utils import TensorType
27
+
28
+
29
+ class LlavaNextProcessor(ProcessorMixin):
30
+ r"""
31
+ Constructs a LLaVa-NeXT processor which wraps a LLaVa-NeXT image processor and a LLaMa tokenizer into a single processor.
32
+
33
+ [`LlavaNextProcessor`] offers all the functionalities of [`LlavaNextImageProcessor`] and [`LlamaTokenizerFast`]. See the
34
+ [`~LlavaNextProcessor.__call__`] and [`~LlavaNextProcessor.decode`] for more information.
35
+
36
+ Args:
37
+ image_processor ([`LlavaNextImageProcessor`], *optional*):
38
+ The image processor is a required input.
39
+ tokenizer ([`LlamaTokenizerFast`], *optional*):
40
+ The tokenizer is a required input.
41
+ """
42
+
43
+ attributes = ["image_processor", "tokenizer"]
44
+ image_processor_class = "LlavaNextImageProcessor"
45
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
46
+
47
+ def __init__(self, image_processor=None, tokenizer=None):
48
+ super().__init__(image_processor, tokenizer)
49
+
50
+ def __call__(
51
+ self,
52
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
53
+ images: ImageInput = None,
54
+ padding: Union[bool, str, PaddingStrategy] = False,
55
+ truncation: Union[bool, str, TruncationStrategy] = None,
56
+ max_length=None,
57
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
58
+ ) -> BatchFeature:
59
+ """
60
+ Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
61
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
62
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
63
+ LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
64
+ of the above two methods for more information.
65
+
66
+ Args:
67
+ text (`str`, `List[str]`, `List[List[str]]`):
68
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
69
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
70
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
71
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
72
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
73
+ tensor. Both channels-first and channels-last formats are supported.
74
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
75
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
76
+ index) among:
77
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
78
+ sequence if provided).
79
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
80
+ acceptable input length for the model if that argument is not provided.
81
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
82
+ lengths).
83
+ max_length (`int`, *optional*):
84
+ Maximum length of the returned list and optionally padding length (see above).
85
+ truncation (`bool`, *optional*):
86
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
87
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
88
+ If set, will return tensors of a particular framework. Acceptable values are:
89
+
90
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
91
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
92
+ - `'np'`: Return NumPy `np.ndarray` objects.
93
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
94
+
95
+ Returns:
96
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
97
+
98
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
99
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
100
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
101
+ `None`).
102
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
103
+ """
104
+ if images is not None:
105
+ image_inputs = self.image_processor(images, return_tensors=return_tensors)
106
+ else:
107
+ image_inputs = {}
108
+ text_inputs = self.tokenizer(
109
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
110
+ )
111
+
112
+ return BatchFeature(data={**text_inputs, **image_inputs})
113
+
114
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
115
+ def batch_decode(self, *args, **kwargs):
116
+ """
117
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
118
+ refer to the docstring of this method for more information.
119
+ """
120
+ return self.tokenizer.batch_decode(*args, **kwargs)
121
+
122
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
123
+ def decode(self, *args, **kwargs):
124
+ """
125
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
126
+ the docstring of this method for more information.
127
+ """
128
+ return self.tokenizer.decode(*args, **kwargs)
129
+
130
+ @property
131
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
132
+ def model_input_names(self):
133
+ tokenizer_input_names = self.tokenizer.model_input_names
134
+ image_processor_input_names = self.image_processor.model_input_names
135
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
multipurpose_chatbot/engines/sealava16_transformers_engine.py CHANGED
@@ -162,8 +162,11 @@ class SeaLlava16Engine(TransformersEngine):
162
  sys.path.append(CODE_PATH)
163
 
164
 
165
- from transformers.models.llava_next.modeling_llava_next import LlavaNextForConditionalGeneration
166
- from transformers.models.llava_next.processing_llava_next import LlavaNextProcessor
 
 
 
167
  model_path = MODEL_PATH
168
  print(f'Loading model from {model_path}')
169
 
@@ -171,8 +174,12 @@ class SeaLlava16Engine(TransformersEngine):
171
  if os.path.exists(f"{model_path}/pytorch_model_fsdp.bin") and not os.path.exists(f"{model_path}/pytorch_model.bin"):
172
  os.symlink("pytorch_model_fsdp.bin", f"{model_path}/pytorch_model.bin")
173
 
174
- self._processor = LlavaNextProcessor.from_pretrained(model_path)
 
 
 
175
  self._model = LlavaNextForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="cuda").eval()
 
176
 
177
  self._model.sample_old = self._model.sample
178
  self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
 
162
  sys.path.append(CODE_PATH)
163
 
164
 
165
+ # from transformers.models.llava_next.modeling_llava_next import LlavaNextForConditionalGeneration
166
+ # from transformers.models.llava_next.processing_llava_next import LlavaNextProcessor
167
+ from .modeling_sealava16 import LlavaNextForConditionalGeneration
168
+ from .image_processing_llava_next import LlavaNextImageProcessor
169
+ from .processing_llava_next import LlavaNextProcessor
170
  model_path = MODEL_PATH
171
  print(f'Loading model from {model_path}')
172
 
 
174
  if os.path.exists(f"{model_path}/pytorch_model_fsdp.bin") and not os.path.exists(f"{model_path}/pytorch_model.bin"):
175
  os.symlink("pytorch_model_fsdp.bin", f"{model_path}/pytorch_model.bin")
176
 
177
+ # self._processor = LlavaNextProcessor.from_pretrained(model_path)
178
+ self._tokenizer = AutoTokenizer.from_pretrained(model_path)
179
+ self._image_processor = LlavaNextImageProcessor(model_path)
180
+ self._processor = LlavaNextProcessor(image_processor=self._image_processor, tokenizer=self._tokenizer)
181
  self._model = LlavaNextForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="cuda").eval()
182
+ print(f'Loading llava1.6 from custom code')
183
 
184
  self._model.sample_old = self._model.sample
185
  self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)