tttoaster commited on
Commit
07fe7de
·
verified ·
1 Parent(s): b78c67e

Upload any_res.py

Browse files
Files changed (1) hide show
  1. src/data/any_res.py +257 -0
src/data/any_res.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import torch
3
+ import math
4
+ import ast
5
+ from PIL import Image
6
+ from io import BytesIO
7
+
8
+
9
+ def select_best_resolution(original_size, possible_resolutions):
10
+ """
11
+ Selects the best resolution from a list of possible resolutions based on the original size.
12
+
13
+ Args:
14
+ original_size (tuple): The original size of the image in the format (width, height).
15
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
16
+
17
+ Returns:
18
+ tuple: The best fit resolution in the format (width, height).
19
+ """
20
+ original_width, original_height = original_size
21
+ best_fit = None
22
+ max_effective_resolution = 0
23
+ min_wasted_resolution = float('inf')
24
+
25
+ for width, height in possible_resolutions:
26
+ scale = min(width / original_width, height / original_height)
27
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
28
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
29
+ wasted_resolution = (width * height) - effective_resolution
30
+
31
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
32
+ max_effective_resolution = effective_resolution
33
+ min_wasted_resolution = wasted_resolution
34
+ best_fit = (width, height)
35
+
36
+ return best_fit
37
+
38
+
39
+ def select_best_resolution_v2(original_size, possible_resolutions):
40
+ """
41
+ Selects the best resolution from a list of possible resolutions based on the original size and aspect ratio.
42
+
43
+ Args:
44
+ original_size (tuple): The original size of the image in the format (width, height).
45
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
46
+
47
+ Returns:
48
+ tuple: The best fit resolution in the format (width, height).
49
+ """
50
+ original_width, original_height = original_size
51
+ original_aspect_ratio = original_height / original_width
52
+ original_area = original_width * original_height
53
+ best_fit = None
54
+ min_aspect_ratio_diff = float('inf')
55
+ min_area_ratio = float('inf')
56
+
57
+ for width, height in possible_resolutions:
58
+ aspect_ratio = height / width
59
+ area = width * height
60
+ aspect_ratio_diff = max(aspect_ratio, original_aspect_ratio) / min(aspect_ratio, original_aspect_ratio)
61
+ area_ratio = max(area, original_area) / min(area, original_area)
62
+
63
+ if aspect_ratio_diff < min_aspect_ratio_diff or (aspect_ratio_diff == min_aspect_ratio_diff and area_ratio < min_area_ratio):
64
+ min_aspect_ratio_diff = aspect_ratio_diff
65
+ min_area_ratio = area_ratio
66
+ best_fit = (width, height)
67
+
68
+ return best_fit
69
+
70
+
71
+ def resize_and_pad_image(image, target_resolution, keep_ratio=False):
72
+ """
73
+ Resize and pad an image to a target resolution
74
+
75
+ Args:
76
+ image (PIL.Image.Image): The input image.
77
+ target_resolution (tuple): The target resolution (width, height) of the image.
78
+
79
+ Returns:
80
+ PIL.Image.Image: The resized and padded image.
81
+ """
82
+ original_width, original_height = image.size
83
+ target_width, target_height = target_resolution
84
+
85
+ if keep_ratio:
86
+ # maintaining aspect ratio
87
+ scale_w = target_width / original_width
88
+ scale_h = target_height / original_height
89
+
90
+ if scale_w < scale_h:
91
+ new_width = target_width
92
+ new_height = min(math.ceil(original_height * scale_w), target_height)
93
+ else:
94
+ new_height = target_height
95
+ new_width = min(math.ceil(original_width * scale_h), target_width)
96
+
97
+ # Resize the image
98
+ resized_image = image.resize((new_width, new_height))
99
+
100
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
101
+ paste_x = (target_width - new_width) // 2
102
+ paste_y = (target_height - new_height) // 2
103
+ new_image.paste(resized_image, (paste_x, paste_y))
104
+ else:
105
+ # not maintaining aspect ratio
106
+ new_image = image.resize((target_width, target_height))
107
+
108
+ return new_image
109
+
110
+
111
+ def divide_to_patches(image, patch_size):
112
+ """
113
+ Divides an image into patches of a specified size.
114
+
115
+ Args:
116
+ image (PIL.Image.Image): The input image.
117
+ patch_size (int): The size of each patch.
118
+
119
+ Returns:
120
+ list: A list of PIL.Image.Image objects representing the patches.
121
+ """
122
+ patches = []
123
+ width, height = image.size
124
+ for i in range(0, height, patch_size):
125
+ for j in range(0, width, patch_size):
126
+ box = (j, i, j + patch_size, i + patch_size)
127
+ patch = image.crop(box)
128
+ patches.append(patch)
129
+
130
+ return patches
131
+
132
+
133
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
134
+ """
135
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
136
+
137
+ Args:
138
+ image_size (tuple): The size of the input image in the format (width, height).
139
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
140
+ patch_size (int): The size of each image patch.
141
+
142
+ Returns:
143
+ tuple: The shape of the image patch grid in the format (width, height).
144
+ """
145
+ if type(grid_pinpoints) is list:
146
+ possible_resolutions = grid_pinpoints
147
+ else:
148
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
149
+ width1, height1 = select_best_resolution(image_size, possible_resolutions)
150
+ width2, height2 = select_best_resolution_v2(image_size, possible_resolutions)
151
+ if width1*height1 > width2*height2:
152
+ width, height = width2, height2
153
+ else:
154
+ width, height = width1, height1
155
+ return width // patch_size, height // patch_size
156
+
157
+
158
+ def process_anyres_image(image, image_transform, grid_pinpoints, base_image_size):
159
+ """
160
+ Process an image with variable resolutions.
161
+
162
+ Args:
163
+ image (PIL.Image.Image): The input image to be processed.
164
+ image_transform: The image processor object.
165
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
166
+
167
+ Returns:
168
+ torch.Tensor: A tensor containing the processed image patches.
169
+ """
170
+ if type(grid_pinpoints) is list:
171
+ possible_resolutions = grid_pinpoints
172
+ else:
173
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
174
+ # best_resolution = select_best_resolution(image.size, possible_resolutions)
175
+ width1, height1 = select_best_resolution(image.size, possible_resolutions)
176
+ width2, height2 = select_best_resolution_v2(image.size, possible_resolutions)
177
+ if width1*height1 > width2*height2:
178
+ width, height = width2, height2
179
+ else:
180
+ width, height = width1, height1
181
+ best_resolution = [width, height]
182
+
183
+ image_padded = resize_and_pad_image(image, best_resolution)
184
+
185
+ patches = divide_to_patches(image_padded, base_image_size)
186
+
187
+ image_original_resize = image.resize((base_image_size, base_image_size))
188
+
189
+ image_patches = patches + [image_original_resize] # add the original image as the last patch
190
+ image_patches = [image_transform(image_patch)
191
+ for image_patch in image_patches]
192
+
193
+ patch_grid = (best_resolution[0]//base_image_size, best_resolution[1]//base_image_size)
194
+ x_index = (torch.arange(patch_grid[0]).repeat(patch_grid[1], 1) + 0.5)/patch_grid[0]
195
+ y_index = (torch.arange(patch_grid[1]).unsqueeze(1).repeat(1, patch_grid[0]) + 0.5)/patch_grid[1]
196
+ patch_pos = torch.stack([x_index, y_index], dim=-1).flatten(0, 1) # h*w, 2
197
+
198
+ origin_pos = torch.tensor([[0.5, 0.5]])
199
+ patch_pos = torch.cat([patch_pos, origin_pos], dim=0) # h*w+1, 2
200
+
201
+ return torch.stack(image_patches, dim=0), patch_pos
202
+
203
+
204
+ def load_image_from_base64(image):
205
+ return Image.open(BytesIO(base64.b64decode(image)))
206
+
207
+
208
+ def anyres_data_collate(batch, tokenizer, dataset_name=None):
209
+ results = {}
210
+ keys = batch[0].keys()
211
+
212
+ for key in keys:
213
+ cur = [batch[i][key] for i in range(len(batch)) if batch[i][key] is not None]
214
+ if len(cur) == 0:
215
+ results[key] = None
216
+ elif isinstance(cur[0], torch.Tensor):
217
+ if key in ['embeds_gen_mask', 'embeds_cmp_mask', 'images', 'images_patch_length', 'patch_position', 'image_size']:
218
+ results[key] = torch.cat(cur, dim=0)
219
+ else:
220
+ if key in ['input_ids']:
221
+ results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=tokenizer.pad_token_id)
222
+ elif key in ['attention_mask']:
223
+ results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=0)
224
+ elif key in ['labels']:
225
+ results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=-100)
226
+ elif key in ['ids_gen_mask', 'ids_cmp_mask']:
227
+ results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=False)
228
+
229
+ else:
230
+ results[key] = torch.stack(cur, dim=0)
231
+ else:
232
+ results[key] = cur
233
+
234
+ results['dataset_name'] = dataset_name
235
+
236
+ return results
237
+
238
+
239
+ def anyres_data_collate_old(batch, dataset_name=None):
240
+ results = {}
241
+ keys = batch[0].keys()
242
+
243
+ for key in keys:
244
+ cur = [batch[i][key] for i in range(len(batch)) if batch[i][key] is not None]
245
+ if len(cur) == 0:
246
+ results[key] = None
247
+ elif isinstance(cur[0], torch.Tensor):
248
+ if key in ['embeds_gen_mask', 'embeds_cmp_mask', 'images', 'images_patch_length', 'patch_position', 'image_size']:
249
+ results[key] = torch.cat(cur, dim=0)
250
+ else:
251
+ results[key] = torch.stack(cur, dim=0)
252
+ else:
253
+ results[key] = cur
254
+
255
+ results['dataset_name'] = dataset_name
256
+
257
+ return results