Update modeling_internlm_xcomposer2.py
#14
by
yuhangzang
- opened
- modeling_internlm_xcomposer2.py +81 -44
modeling_internlm_xcomposer2.py
CHANGED
@@ -287,69 +287,93 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
287 |
}
|
288 |
return inputs, wrap_im_mask, temp_len
|
289 |
|
290 |
-
def interleav_wrap(self, img_list, text_list):
|
291 |
-
|
292 |
-
|
|
|
293 |
|
294 |
-
|
295 |
-
|
296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
parts = text.split('<ImageHere>')
|
298 |
-
|
|
|
299 |
temp_len = 0
|
300 |
-
image_nums, im_len = img_embeds.shape[:2]
|
301 |
need_bos = True
|
302 |
for idx, part in enumerate(parts):
|
303 |
if len(part) > 0:
|
304 |
-
part_tokens = self.tokenizer(
|
305 |
-
|
306 |
-
return_tensors='pt',
|
307 |
-
padding='longest',
|
308 |
-
add_special_tokens=need_bos).to(self.device)
|
309 |
if need_bos:
|
310 |
need_bos = False
|
311 |
wrap_tokens.append(part_tokens.input_ids)
|
312 |
-
part_embeds = self.model.tok_embeddings(
|
313 |
-
part_tokens.input_ids)
|
314 |
wrap_embeds.append(part_embeds)
|
315 |
-
|
316 |
-
wrap_im_mask.append(
|
317 |
-
torch.zeros(part_embeds.shape[:2]).to(self.device))
|
318 |
-
|
319 |
temp_len += part_embeds.shape[1]
|
320 |
-
if idx <
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
wrap_im_mask.append(
|
325 |
-
|
326 |
-
|
327 |
-
temp_len += im_len
|
328 |
if temp_len > self.max_length:
|
329 |
break
|
330 |
-
|
331 |
wrap_tokens = torch.cat(wrap_tokens, dim=1)
|
332 |
wrap_embeds = torch.cat(wrap_embeds, dim=1)
|
333 |
-
wrap_atts = torch.cat(wrap_atts, dim=1)
|
334 |
wrap_im_mask = torch.cat(wrap_im_mask, dim=1)
|
335 |
|
336 |
wrap_target = self.mask_human_targets(wrap_tokens).to(self.device)
|
337 |
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
|
348 |
-
|
349 |
-
wrap_atts = torch.cat(wrap_atts_list)
|
350 |
-
wrap_target = torch.cat(wrap_target_list)
|
351 |
-
wrap_im_mask = torch.cat(wrap_im_mask_list)
|
352 |
-
return wrap_embeds, wrap_atts, wrap_target, wrap_im_mask
|
353 |
|
354 |
def mask_human_targets(self, input_ids, pure=False):
|
355 |
target_batch = []
|
@@ -416,9 +440,22 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
416 |
text = samples['text_input']
|
417 |
# encode image
|
418 |
if has_img:
|
419 |
-
image = samples['image']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
|
421 |
-
image, text)
|
422 |
else:
|
423 |
to_regress_tokens, targets = self.text2emb(
|
424 |
text, add_special_tokens=True)
|
|
|
287 |
}
|
288 |
return inputs, wrap_im_mask, temp_len
|
289 |
|
290 |
+
def interleav_wrap(self, img_list, text_list, image_nums):
|
291 |
+
temp_embeds = []
|
292 |
+
temp_im_mask = []
|
293 |
+
temp_tars = []
|
294 |
|
295 |
+
# encode_image
|
296 |
+
img_embeds, img_split = self.vit(img_list, self.plora_glb_GN, self.plora_sub_GN)
|
297 |
+
img_embeds = self.vision_proj(img_embeds)
|
298 |
+
|
299 |
+
text_list = text_list[0]
|
300 |
+
for idx, text in enumerate(text_list):
|
301 |
+
image_num = image_nums[idx]
|
302 |
+
im_id = int(np.sum(image_nums[:idx]))
|
303 |
+
images = []
|
304 |
+
for i in range(image_nums[idx]):
|
305 |
+
st = int(np.sum(img_split[:im_id + i]))
|
306 |
+
sp = img_split[im_id + i]
|
307 |
+
temp_img = img_embeds[:, st:st+sp]
|
308 |
+
images.append(temp_img)
|
309 |
+
atts_img = torch.ones((len(images), images[0].shape[1]), dtype=torch.long).to(self.device)
|
310 |
+
img_target = torch.ones(
|
311 |
+
(len(images), images[0].shape[1]), dtype=torch.long).to(
|
312 |
+
self.device) * -100
|
313 |
+
|
314 |
+
if image_num == 1 and text.find('<ImageHere>') == -1:
|
315 |
+
text = '<ImageHere>' + text
|
316 |
parts = text.split('<ImageHere>')
|
317 |
+
|
318 |
+
wrap_tokens, wrap_embeds, wrap_im_mask = [], [], []
|
319 |
temp_len = 0
|
|
|
320 |
need_bos = True
|
321 |
for idx, part in enumerate(parts):
|
322 |
if len(part) > 0:
|
323 |
+
part_tokens = self.tokenizer(part, return_tensors='pt', padding='longest',
|
324 |
+
add_special_tokens=need_bos).to(self.device)
|
|
|
|
|
|
|
325 |
if need_bos:
|
326 |
need_bos = False
|
327 |
wrap_tokens.append(part_tokens.input_ids)
|
328 |
+
part_embeds = self.model.tok_embeddings(part_tokens.input_ids)
|
|
|
329 |
wrap_embeds.append(part_embeds)
|
330 |
+
wrap_im_mask.append(torch.zeros(part_embeds.shape[:2]).to(self.device))
|
|
|
|
|
|
|
331 |
temp_len += part_embeds.shape[1]
|
332 |
+
if idx < image_num:
|
333 |
+
wrap_embeds.append(images[idx])
|
334 |
+
wrap_token = torch.ones(images[idx].shape[:2], dtype=torch.long).to(self.device) * -100
|
335 |
+
wrap_tokens.append(wrap_token)
|
336 |
+
wrap_im_mask.append(torch.ones(images[idx].shape[:2]).to(self.device))
|
337 |
+
temp_len += images[idx].shape[1]
|
|
|
|
|
338 |
if temp_len > self.max_length:
|
339 |
break
|
|
|
340 |
wrap_tokens = torch.cat(wrap_tokens, dim=1)
|
341 |
wrap_embeds = torch.cat(wrap_embeds, dim=1)
|
|
|
342 |
wrap_im_mask = torch.cat(wrap_im_mask, dim=1)
|
343 |
|
344 |
wrap_target = self.mask_human_targets(wrap_tokens).to(self.device)
|
345 |
|
346 |
+
temp_embeds.append(wrap_embeds)
|
347 |
+
temp_im_mask.append(wrap_im_mask)
|
348 |
+
temp_tars.append(wrap_target)
|
349 |
+
|
350 |
+
temp_max_len = np.max([i.shape[1] for i in temp_embeds])
|
351 |
+
temp_max_len = min(temp_max_len, self.max_length)
|
352 |
+
|
353 |
+
final_input, final_atts, final_tars, final_mask = [], [], [], []
|
354 |
+
pad = torch.ones([1, 1]) * self.tokenizer.pad_token_id
|
355 |
+
pad = pad.long().to(self.device)
|
356 |
+
pad_emb = self.model.tok_embeddings(pad)
|
357 |
+
|
358 |
+
for idx in range(len(temp_embeds)):
|
359 |
+
temp_len = temp_embeds[idx].shape[1]
|
360 |
+
if temp_len >= temp_max_len:
|
361 |
+
final_input.append(temp_embeds[idx][:, :temp_max_len])
|
362 |
+
final_atts.append(torch.ones(1, temp_max_len).to(wrap_target.dtype).to(self.device))
|
363 |
+
final_tars.append(temp_tars[idx][:, :temp_max_len])
|
364 |
+
final_mask.append(temp_im_mask[idx][:, :temp_max_len])
|
365 |
+
else:
|
366 |
+
final_input.append(torch.cat([temp_embeds[idx], pad_emb.repeat(1, temp_max_len-temp_len, 1)], dim=1))
|
367 |
+
final_atts.append(torch.cat([torch.ones(1, temp_len), torch.zeros(1, temp_max_len-temp_len)], dim=1).to(wrap_target.dtype).to(self.device))
|
368 |
+
final_tars.append(torch.cat([temp_tars[idx], (torch.ones(1, temp_max_len-temp_len)*-100).to(wrap_target.dtype).to(self.device)], dim=1))
|
369 |
+
final_mask.append(torch.cat([temp_im_mask[idx], (torch.zeros(1, temp_max_len-temp_len)).to(wrap_target.dtype).to(self.device)], dim=1))
|
370 |
|
371 |
+
inputs_embeds = torch.cat(final_input, dim=0)
|
372 |
+
attention_mask = torch.cat(final_atts, dim=0)
|
373 |
+
targets = torch.cat(final_tars, dim=0)
|
374 |
+
im_mask = torch.cat(final_mask, dim=0)
|
375 |
|
376 |
+
return inputs_embeds, attention_mask, targets, im_mask
|
|
|
|
|
|
|
|
|
377 |
|
378 |
def mask_human_targets(self, input_ids, pure=False):
|
379 |
target_batch = []
|
|
|
440 |
text = samples['text_input']
|
441 |
# encode image
|
442 |
if has_img:
|
443 |
+
image = samples['image'][0]
|
444 |
+
bs = len(samples['text_input'][0])
|
445 |
+
image_nums = []
|
446 |
+
temp_image = []
|
447 |
+
for im in image:
|
448 |
+
if type(im) is list:
|
449 |
+
image_nums.append(len(im))
|
450 |
+
temp_image.extend(im)
|
451 |
+
else:
|
452 |
+
image_nums.append(1)
|
453 |
+
temp_image.append(im)
|
454 |
+
image = temp_image
|
455 |
+
assert type(image) is list and len(image_nums) == bs
|
456 |
+
|
457 |
to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
|
458 |
+
image, text, image_nums)
|
459 |
else:
|
460 |
to_regress_tokens, targets = self.text2emb(
|
461 |
text, add_special_tokens=True)
|