Upload modeling_sa2va_chat.py with huggingface_hub
modeling_sa2va_chat.py (CHANGED, +130 -106)
Before (removed lines marked with "-"; removed lines whose content could not be recovered are elided as "..."):

@@ -594,116 +594,137 @@ class Sa2VAChatModel(PreTrainedModel):
         assert tokenizer
         self.preparing_for_generation(tokenizer=tokenizer)

-        ...
-                input_dict['vp_overall_mask'] = None
         else:
-            ...

-            images = dynamic_preprocess(image, self.min_dynamic_patch,
-                                        self.max_dynamic_patch,
-                                        self.image_size, self.use_thumbnail)
-
-            if mask_prompts is not None:
-                vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
-                input_dict['vp_overall_mask'] = vp_overall_mask
-            else:
                 input_dict['vp_overall_mask'] = None

-        ...
-            region_pixels = []
-            for mask_prompt in mask_prompts[0]:
-                region_pixels.append(mask_prompt.bool().to(torch.int64).sum())
-
-            vp_token_str = '\nThere are {} part regions in the picture: '.format(len(mask_prompts[0]))
-            for i in range(len(mask_prompts[0])):
-                vp_token_str = vp_token_str + \
-                    f"region{i + 1}" + self.VP_START_TOKEN + \
-                    self.IMG_CONTEXT_TOKEN * region_pixels[i] + \
-                    self.VP_END_TOKEN
-                if i == len(mask_prompts[0]) - 1:
-                    vp_token_str = vp_token_str + '.\n'
                 else:
-                    ...
-        else:
-            vp_token_str = ''
-
-        image_token_str = f'{self.IMG_START_TOKEN}' \
-                          f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
-                          f'{self.IMG_END_TOKEN}'
-        image_token_str = image_token_str + '\n'
-        image_token_str = image_token_str * num_frames
-        image_token_str = image_token_str.strip()
-
-        ret_masks = []
-
-        if '<image>' in text or mask_prompts is not None:
-            assert past_text is None or len(past_text) == 0
-            text = text.replace('<image>', image_token_str + vp_token_str)
-        input_text = ''
-        input_text += self.template['INSTRUCTION'].format(
-            input=text, round=1, bot_name=self.bot_name)
-        input_text = past_text + input_text
-        ids = self.tokenizer.encode(input_text)
-        ret_past_text = self.tokenizer.decode(ids)
-        ids = torch.tensor(ids).cuda().unsqueeze(0)

-        ...

         generate_output = self.generate(
             **mm_inputs,
@@ -716,8 +737,10 @@ class Sa2VAChatModel(PreTrainedModel):
         )
         predict = self.tokenizer.decode(
             generate_output.sequences[0], skip_special_tokens=False).strip()
-        ...
         # if have seg result, find the seg hidden states
         hidden_states = generate_output.hidden_states
         last_hidden_states = [item[-1][0] for item in hidden_states]
@@ -739,7 +762,8 @@ class Sa2VAChatModel(PreTrainedModel):
         masks = masks.sigmoid() > 0.5
         masks = masks.cpu().numpy()
         ret_masks.append(masks)
-        ...

 def get_seg_hidden_states(hidden_states, output_ids, seg_id):
     seg_mask = output_ids == seg_id
After (added lines marked with "+"):

         assert tokenizer
         self.preparing_for_generation(tokenizer=tokenizer)

+        if image is None and video is None and '<image>' not in past_text:
+            text = text.replace('<image>', "")
+            input_text = ''
+            input_text += self.template['INSTRUCTION'].format(
+                input=text, round=1, bot_name=self.bot_name)
+            input_text = past_text + input_text
+            ids = self.tokenizer.encode(input_text)
+            ids = torch.tensor(ids).cuda().unsqueeze(0)
+
+            attention_mask = torch.ones_like(ids, dtype=torch.bool)
+
+            mm_inputs = {
+                'pixel_values': None,
+                'input_ids': ids,
+                'attention_mask': attention_mask,
+                'position_ids': None,
+                'past_key_values': None,
+                'labels': None,
+                'prompt_masks': None,
+                'vp_overall_mask': None,
+            }
+            ret_masks = []
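Aside: a minimal, self-contained sketch of what the text-only branch above hands to generate on the tokenizer side. The token ids below are made-up placeholders; in the real path they come from self.tokenizer and the INSTRUCTION template.

import torch

# Stand-in for ids = self.tokenizer.encode(input_text); any id list illustrates the shapes.
ids = torch.tensor([1, 887, 29901, 6324, 2]).unsqueeze(0)    # (1, seq_len)

# Mirrors the branch above: a single unpadded sequence, so the mask is all True.
attention_mask = torch.ones_like(ids, dtype=torch.bool)

print(ids.shape)             # torch.Size([1, 5])
print(attention_mask.all())  # tensor(True)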
         else:
+            input_dict = {}
+            if video is not None:
+                pixel_values = []
+                extra_pixel_values = []
+                ori_image_size = video[0].size
+                for frame_idx, frame_image in enumerate(video):
+                    assert ori_image_size == frame_image.size
+                    g_image = np.array(frame_image) # for grounding
+                    g_image = self.extra_image_processor.apply_image(g_image)
+                    g_image = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
+                    extra_pixel_values.append(g_image)
+                    if frame_idx < 5:
+                        img = self.transformer(frame_image)
+                        pixel_values.append(img)
+
+                pixel_values = torch.stack(pixel_values, dim=0).to(self.torch_dtype) # (n_f, 3, h, w)
+                g_pixel_values = torch.stack([
+                    self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
+                ]).to(self.torch_dtype)
+                num_image_tokens = self.patch_token
+                num_frames = len(pixel_values)

                 input_dict['vp_overall_mask'] = None
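Aside: a toy check of the frame bookkeeping in the video branch, with no model involved. Every frame contributes a grounding image, but only the first five frames pass through the ViT transform (the `if frame_idx < 5` guard), so num_frames is capped at 5; the 12-frame clip below is an arbitrary example.

n_input_frames = 12
extra_pixel_values = [f"g_frame_{i}" for i in range(n_input_frames)]       # one grounding image per frame
pixel_values = [f"vit_frame_{i}" for i in range(n_input_frames) if i < 5]  # mirrors `if frame_idx < 5`

num_frames = len(pixel_values)
print(len(extra_pixel_values), num_frames)  # 12 5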
+            else:
+                ori_image_size = image.size

+                # prepare grounding images
+                g_image = np.array(image) # for grounding
+                g_image = self.extra_image_processor.apply_image(g_image)
+                g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous().to(self.torch_dtype)
+                extra_pixel_values = [g_pixel_values]
+                g_pixel_values = torch.stack([
+                    self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
+                ]).to(self.torch_dtype)
+
+                images = dynamic_preprocess(image, self.min_dynamic_patch,
+                                            self.max_dynamic_patch,
+                                            self.image_size, self.use_thumbnail)
+
+                if mask_prompts is not None:
+                    vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
+                    input_dict['vp_overall_mask'] = vp_overall_mask
                 else:
+                    input_dict['vp_overall_mask'] = None

+                pixel_values = [self.transformer(image) for image in images]
+                pixel_values = torch.stack(pixel_values).to(self.torch_dtype)
+                num_image_tokens = pixel_values.shape[0] * self.patch_token
+                num_frames = 1
+            input_dict['g_pixel_values'] = g_pixel_values
+            input_dict['pixel_values'] = pixel_values

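Aside: a small sketch of how the tile count coming out of dynamic_preprocess drives the rest of the image branch. The choice of 5 crops and 256 tokens per crop is illustrative; the real numbers depend on the image aspect ratio, the min/max dynamic patch settings, and self.patch_token. Reading the final True flag as marking the overall (thumbnail) crop is an assumption based on the use_thumbnail argument above.

import torch

num_tiles = 5        # assumed len(images) returned by dynamic_preprocess
patch_token = 256    # assumed self.patch_token

num_image_tokens = num_tiles * patch_token            # mirrors pixel_values.shape[0] * self.patch_token
vp_overall_mask = torch.Tensor([False] * (num_tiles - 1) + [True])

print(num_image_tokens)  # 1280
print(vp_overall_mask)   # tensor([0., 0., 0., 0., 1.])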
+            if mask_prompts is not None:
+                # reshape mask prompts to feature size
+                mask_prompts = [torch.Tensor(item).to(pixel_values.device) for item in mask_prompts]
+                mask_prompts = [F.interpolate(
+                    item.unsqueeze(0),
+                    size=(int(self.image_size // self.patch_size * self.downsample_ratio),
+                          int(self.image_size // self.patch_size * self.downsample_ratio)),
+                    mode='nearest').squeeze(0) for item in mask_prompts]
+                region_pixels = []
+                for mask_prompt in mask_prompts[0]:
+                    region_pixels.append(mask_prompt.bool().to(torch.int64).sum())
+
+                vp_token_str = '\nThere are {} part regions in the picture: '.format(len(mask_prompts[0]))
+                for i in range(len(mask_prompts[0])):
+                    vp_token_str = vp_token_str + \
+                        f"region{i + 1}" + self.VP_START_TOKEN + \
+                        self.IMG_CONTEXT_TOKEN * region_pixels[i] + \
+                        self.VP_END_TOKEN
+                    if i == len(mask_prompts[0]) - 1:
+                        vp_token_str = vp_token_str + '.\n'
+                    else:
+                        vp_token_str = vp_token_str + ', '
+            else:
+                vp_token_str = ''
+
+            image_token_str = f'{self.IMG_START_TOKEN}' \
+                              f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
+                              f'{self.IMG_END_TOKEN}'
+            image_token_str = image_token_str + '\n'
+            image_token_str = image_token_str * num_frames
+            image_token_str = image_token_str.strip()
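Aside: a runnable toy version of the visual-prompt bookkeeping above. Two fake binary region masks are downsampled with nearest-neighbour interpolation to a 4x4 "feature grid"; each region then reserves exactly region_pixels[i] context-token slots in vp_token_str. The token literals are placeholders; the real strings live on the model as self.IMG_CONTEXT_TOKEN, self.VP_START_TOKEN and self.VP_END_TOKEN.

import torch
import torch.nn.functional as F

IMG_CONTEXT_TOKEN, VP_START_TOKEN, VP_END_TOKEN = '<CTX>', '<VP>', '</VP>'  # placeholder literals

# Two toy binary region masks on a 16x16 image, downsampled to a 4x4 feature grid.
masks = torch.zeros(2, 16, 16)
masks[0, :8, :8] = 1
masks[1, 4:, 4:] = 1
masks = F.interpolate(masks.unsqueeze(0), size=(4, 4), mode='nearest').squeeze(0)

region_pixels = [m.bool().to(torch.int64).sum() for m in masks]
print([int(r) for r in region_pixels])  # [4, 9] -> context tokens reserved per region

vp_token_str = '\nThere are {} part regions in the picture: '.format(len(masks))
for i, r in enumerate(region_pixels):
    vp_token_str += f"region{i + 1}" + VP_START_TOKEN + IMG_CONTEXT_TOKEN * int(r) + VP_END_TOKEN
    vp_token_str += '.\n' if i == len(masks) - 1 else ', '
print(vp_token_str)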
+
+            ret_masks = []
+
+            if '<image>' in text or mask_prompts is not None:
+                assert past_text is None or len(past_text) == 0
+                text = text.replace('<image>', image_token_str + vp_token_str)
+            input_text = ''
+            input_text += self.template['INSTRUCTION'].format(
+                input=text, round=1, bot_name=self.bot_name)
+            input_text = past_text + input_text
+            ids = self.tokenizer.encode(input_text)
+            ids = torch.tensor(ids).cuda().unsqueeze(0)
+
+            attention_mask = torch.ones_like(ids, dtype=torch.bool)
+
+            mm_inputs = {
+                'pixel_values': input_dict['pixel_values'],
+                'input_ids': ids,
+                'attention_mask': attention_mask,
+                'position_ids': None,
+                'past_key_values': None,
+                'labels': None,
+                'prompt_masks': mask_prompts,
+                'vp_overall_mask': input_dict['vp_overall_mask'],
+            }

         generate_output = self.generate(
             **mm_inputs,
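Aside: the code after this call relies on generate_output.sequences and generate_output.hidden_states, which is standard transformers behaviour when generate is asked for a dict and hidden states; the actual keyword arguments of this self.generate call sit in lines not shown in the hunk. A plain GPT-2 sketch of that contract:

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")
ids = tok("hello there", return_tensors="pt").input_ids

out = lm.generate(ids, max_new_tokens=4,
                  return_dict_in_generate=True, output_hidden_states=True)
print(out.sequences.shape)     # (1, prompt_len + 4)
print(len(out.hidden_states))  # 4: one tuple of per-layer states per generated token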
@@ -716,8 +737,10 @@ class Sa2VAChatModel(PreTrainedModel):
         )
         predict = self.tokenizer.decode(
             generate_output.sequences[0], skip_special_tokens=False).strip()
+
+        if image is None and video is None and '<image>' not in past_text:
+            return {'prediction': predict, 'prediction_masks': ret_masks, }
+
         # if have seg result, find the seg hidden states
         hidden_states = generate_output.hidden_states
         last_hidden_states = [item[-1][0] for item in hidden_states]
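Aside: a toy illustration of the selection that the trailing get_seg_hidden_states helper (bottom of the diff) starts from: a boolean mask over the generated ids picks out the hidden-state rows sitting at the [SEG] positions, which later turn into the predicted masks (per the seg-result comment above). The id value and hidden size below are made up.

import torch

seg_id = 92550                        # placeholder; the real id comes from the tokenizer
output_ids = torch.tensor([11, 92550, 7, 92550, 2])
hidden_states = torch.randn(5, 4096)  # one row per generated token, toy hidden size

seg_mask = output_ids == seg_id       # mirrors the last line of the diff
print(seg_mask)                       # tensor([False,  True, False,  True, False])
print(hidden_states[seg_mask].shape)  # torch.Size([2, 4096])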
@@ -739,7 +762,8 @@ class Sa2VAChatModel(PreTrainedModel):
         masks = masks.sigmoid() > 0.5
         masks = masks.cpu().numpy()
         ret_masks.append(masks)
+
+        return {'prediction': predict, 'prediction_masks': ret_masks,}

 def get_seg_hidden_states(hidden_states, output_ids, seg_id):
     seg_mask = output_ids == seg_id
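Aside: a caller-side sketch of the code path touched by this commit. The 'prediction' and 'prediction_masks' keys come straight from the return statements above; everything else (the checkpoint id, loading via AutoModel with trust_remote_code, and predict_forward as the name of the surrounding method) is an assumption based on how Sa2VA checkpoints are usually published, not something shown in this diff.

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

path = "ByteDance/Sa2VA-4B"  # assumed checkpoint id
model = AutoModel.from_pretrained(path, torch_dtype=torch.bfloat16,
                                  trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

image = Image.open("demo.jpg").convert("RGB")
result = model.predict_forward(        # assumed entry point wrapping the code above
    image=image,
    text="<image>Please segment the person on the left.",
    tokenizer=tokenizer)

print(result["prediction"])             # answer text, may contain [SEG]
print(len(result["prediction_masks"]))  # one mask array per segmented target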