Update modeling_minicpmv.py

#39
Files changed (1)
  1. modeling_minicpmv.py +96 -7
modeling_minicpmv.py CHANGED
@@ -42,13 +42,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
 
         return model
 
-    def init_resampler(self, embed_dim, vision_dim):
+    def init_resampler(self, embed_dim, vision_dim,):
         return Resampler(
             num_queries=self.config.query_num,
             embed_dim=embed_dim,
             num_heads=embed_dim // 128,
             kv_dim=vision_dim,
-            adaptive=True
+            adaptive=True,
         )
 
     def init_transform(self):
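The only change in the hunk above is the trailing commas on the `init_resampler` signature and the `Resampler(...)` call; the construction itself is untouched. For orientation, here is a minimal standalone sketch of the head-count arithmetic that call relies on, using hypothetical dimensions rather than values from this repo's config:

```python
# Standalone sketch, not part of the diff: the num_heads arithmetic used when
# the Resampler is constructed. embed_dim and vision_dim are hypothetical here.
embed_dim = 4096                 # hypothetical LLM hidden size
vision_dim = 1152                # hypothetical vision encoder width (kv_dim)
num_heads = embed_dim // 128     # one head per 128 channels
assert embed_dim % num_heads == 0
print(num_heads)                 # 32
```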
@@ -60,13 +60,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                 ),
             ]
         )
-
+
     def get_input_embeddings(self):
         return self.llm.get_input_embeddings()
 
     def set_input_embeddings(self, value):
         self.llm.embed_tokens = value
-
+
     def get_vllm_embedding(self, data):
         if 'vision_hidden_states' not in data:
             dtype = self.vpm.embeddings.position_embedding.weight.dtype
@@ -152,16 +152,105 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                     image_indices = torch.stack(
                         [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                     ).to(vllm_embedding.device)
-
                     cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                                           cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
                 elif self.training:
                     cur_vllm_emb += cur_vs_hs[0].mean() * 0
 
         return vllm_embedding, vision_hidden_states
-
+
     def forward(self, data, **kwargs):
-        vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
+
+        if 'vision_hidden_states' not in data:
+            dtype = self.llm.lm_head.weight.dtype
+            device = self.llm.lm_head.weight.device
+            tgt_sizes = data['tgt_sizes']
+            pixel_values_list = data['pixel_values']
+            vision_hidden_states = []
+            all_pixel_values = []
+            img_cnt = []
+            for pixel_values in pixel_values_list:
+                img_cnt.append(len(pixel_values))
+                all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
+
+            # exist image
+            if all_pixel_values:
+                tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
+
+                if self.config.batch_vision_input:
+                    max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
+
+                    all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True,
+                                                                       padding_value=0.0)
+                    B, L, _ = all_pixel_values.shape
+                    all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+
+                    patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
+                    for i in range(B):
+                        patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+
+                    vision_embedding = self.vpm(all_pixel_values.type(dtype), patch_attention_mask=patch_attn_mask).last_hidden_state
+                    vision_embedding = self.resampler(vision_embedding, tgt_sizes)
+                else:
+                    # get vision_embedding foreach
+                    vision_embedding = []
+                    for single_tgt_size, single_pixel_values in zip(tgt_sizes, all_pixel_values):
+                        single_pixel_values = single_pixel_values.unsqueeze(0)
+                        B, L, _ = single_pixel_values.shape
+                        single_pixel_values = single_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+                        single_vision_embedding = self.vpm(single_pixel_values.type(dtype)).last_hidden_state
+                        single_vision_embedding = self.resampler(single_vision_embedding, single_tgt_size.unsqueeze(0))
+                        vision_embedding.append(single_vision_embedding)
+                    vision_embedding = torch.vstack(vision_embedding)
+
+                start = 0
+                for pixel_values in pixel_values_list:
+                    img_cnt = len(pixel_values)
+                    if img_cnt > 0:
+                        vision_hidden_states.append(vision_embedding[start: start + img_cnt])
+                        start += img_cnt
+                    else:
+                        vision_hidden_states.append([])
+            else: # no image
+                if self.training:
+                    dummy_image = torch.zeros(
+                        (1, 3, 224, 224),
+                        device=device, dtype=dtype
+                    )
+                    tgt_sizes = torch.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).type(torch.int32)
+                    dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
+                else:
+                    dummy_feature = []
+                for _ in range(len(pixel_values_list)):
+                    vision_hidden_states.append(dummy_feature)
+
+        else:
+            vision_hidden_states = data['vision_hidden_states']
+
+        if hasattr(self.llm.config, 'scale_emb'):
+            vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
+        else:
+            vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
+
+        vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(
+            i, torch.Tensor) else i for i in vision_hidden_states]
+
+        bs = len(data['input_ids'])
+        for i in range(bs):
+            cur_vs_hs = vision_hidden_states[i]
+            if len(cur_vs_hs) > 0:
+                cur_vllm_emb = vllm_embedding[i]
+                cur_image_bound = data['image_bound'][i]
+                if len(cur_image_bound) > 0:
+                    image_indices = torch.stack(
+                        [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
+                    ).to(vllm_embedding.device)
+                    cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
+                                          cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
+                elif self.training:
+                    cur_vllm_emb += cur_vs_hs[0].mean() * 0
+
+        # vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
         position_ids = data["position_ids"]
         if position_ids.dtype != torch.int64:
             position_ids = position_ids.long()
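The bulk of the hunk above inlines the batched vision path into `forward`: per-image patch sequences of different lengths are right-padded into a single tensor and paired with a boolean patch mask before being passed to `self.vpm`. Below is a minimal standalone sketch of that padding and masking step with invented shapes (the mask's singleton dimension is indexed explicitly here):

```python
import torch

# Toy illustration of the pad_sequence + patch_attn_mask step from the new
# forward path. tgt_sizes holds (h_patches, w_patches) per image; the sizes
# and feature width below are invented for the example.
tgt_sizes = torch.tensor([[4, 6], [8, 8]], dtype=torch.int32)
feat_dim = 588                                   # hypothetical per-patch feature size
seqs = [torch.randn(int(h * w), feat_dim) for h, w in tgt_sizes]

padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0.0)
B, L, _ = padded.shape                           # L == max(h * w) == 64

patch_attn_mask = torch.zeros((B, 1, L), dtype=torch.bool)
for i in range(B):
    n_valid = int(tgt_sizes[i][0] * tgt_sizes[i][1])
    patch_attn_mask[i, 0, :n_valid] = True       # real patches True, padding stays False

print(padded.shape)                              # torch.Size([2, 64, 588])
print(patch_attn_mask.sum(dim=-1).squeeze(-1))   # tensor([24, 64])
```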
 
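Once the vision features exist (either freshly computed or taken from `data['vision_hidden_states']`), each sample's `image_bound` spans are overwritten in place with the resampled features via `scatter_`; since `cur_vllm_emb` is a view of `vllm_embedding[i]`, the in-place scatter writes directly into the batch embedding tensor. A standalone sketch of that splice with toy tensors (all shapes and values invented):

```python
import torch

# Toy illustration of the per-sample scatter_ splice: token embeddings inside
# each (start, end) image span are replaced by the corresponding vision rows.
hidden = 8
seq_emb = torch.zeros(10, hidden)                # toy token embeddings for one sample
image_bound = torch.tensor([[2, 5], [6, 9]])     # two image spans, 3 slots each
vision_feats = torch.ones(2, 3, hidden)          # toy resampled features, one row per slot

image_indices = torch.stack(
    [torch.arange(r[0], r[1], dtype=torch.long) for r in image_bound]
)                                                # shape (2, 3)

seq_emb.scatter_(0,
                 image_indices.view(-1, 1).repeat(1, seq_emb.shape[-1]),
                 vision_feats.view(-1, vision_feats.shape[-1]))

print(seq_emb[:, 0])                             # 1.0 at positions 2-4 and 6-8, 0.0 elsewhere
```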