Xiangtai committed (verified)
Commit dc5440e · 1 Parent(s): 90238cb

Upload modeling_sa2va_chat.py with huggingface_hub
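For reference, a minimal sketch of how an upload like this is typically performed with huggingface_hub (the repo id below is an illustrative placeholder, not taken from this commit):

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="modeling_sa2va_chat.py",   # local file to push
        path_in_repo="modeling_sa2va_chat.py",      # destination path in the repo
        repo_id="<namespace>/<repo-name>",          # illustrative placeholder
        commit_message="Upload modeling_sa2va_chat.py with huggingface_hub",
    )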

Files changed (1):
  modeling_sa2va_chat.py  +130 −106
modeling_sa2va_chat.py CHANGED
@@ -594,116 +594,137 @@ class Sa2VAChatModel(PreTrainedModel):
         assert tokenizer
         self.preparing_for_generation(tokenizer=tokenizer)
 
-        input_dict = {}
-        if video is not None:
-            pixel_values = []
-            extra_pixel_values = []
-            ori_image_size = video[0].size
-            for frame_idx, frame_image in enumerate(video):
-                assert ori_image_size == frame_image.size
-                g_image = np.array(frame_image)  # for grounding
-                g_image = self.extra_image_processor.apply_image(g_image)
-                g_image = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
-                extra_pixel_values.append(g_image)
-                if frame_idx < 5:
-                    img = self.transformer(frame_image)
-                    pixel_values.append(img)
-
-            pixel_values = torch.stack(pixel_values, dim=0).to(self.torch_dtype)  # (n_f, 3, h, w)
-            g_pixel_values = torch.stack([
-                self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
-            ]).to(self.torch_dtype)
-            num_image_tokens = self.patch_token
-            num_frames = 5
-
-            input_dict['vp_overall_mask'] = None
-        else:
-            ori_image_size = image.size
-
-            # prepare grounding images
-            g_image = np.array(image)  # for grounding
-            g_image = self.extra_image_processor.apply_image(g_image)
-            g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous().to(self.torch_dtype)
-            extra_pixel_values = [g_pixel_values]
-            g_pixel_values = torch.stack([
-                self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
-            ]).to(self.torch_dtype)
-
-            images = dynamic_preprocess(image, self.min_dynamic_patch,
-                                        self.max_dynamic_patch,
-                                        self.image_size, self.use_thumbnail)
-
-            if mask_prompts is not None:
-                vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
-                input_dict['vp_overall_mask'] = vp_overall_mask
-            else:
-                input_dict['vp_overall_mask'] = None
-
-            pixel_values = [self.transformer(image) for image in images]
-            pixel_values = torch.stack(pixel_values).to(self.torch_dtype)
-            num_image_tokens = pixel_values.shape[0] * self.patch_token
-            num_frames = 1
-        input_dict['g_pixel_values'] = g_pixel_values
-        input_dict['pixel_values'] = pixel_values
-
-
-        if mask_prompts is not None:
-            # reshape mask prompts to feature size
-            mask_prompts = [torch.Tensor(item).to(pixel_values.device) for item in mask_prompts]
-            mask_prompts = [F.interpolate(
-                item.unsqueeze(0),
-                size=(int(self.image_size // self.patch_size * self.downsample_ratio),
-                      int(self.image_size // self.patch_size * self.downsample_ratio)),
-                mode='nearest').squeeze(0) for item in mask_prompts]
-            region_pixels = []
-            for mask_prompt in mask_prompts[0]:
-                region_pixels.append(mask_prompt.bool().to(torch.int64).sum())
-
-            vp_token_str = '\nThere are {} part regions in the picture: '.format(len(mask_prompts[0]))
-            for i in range(len(mask_prompts[0])):
-                vp_token_str = vp_token_str + \
-                    f"region{i + 1}" + self.VP_START_TOKEN + \
-                    self.IMG_CONTEXT_TOKEN * region_pixels[i] + \
-                    self.VP_END_TOKEN
-                if i == len(mask_prompts[0]) - 1:
-                    vp_token_str = vp_token_str + '.\n'
-                else:
-                    vp_token_str = vp_token_str + ', '
-        else:
-            vp_token_str = ''
-
-        image_token_str = f'{self.IMG_START_TOKEN}' \
-                          f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
-                          f'{self.IMG_END_TOKEN}'
-        image_token_str = image_token_str + '\n'
-        image_token_str = image_token_str * num_frames
-        image_token_str = image_token_str.strip()
-
-        ret_masks = []
-
-        if '<image>' in text or mask_prompts is not None:
-            assert past_text is None or len(past_text) == 0
-            text = text.replace('<image>', image_token_str + vp_token_str)
-        input_text = ''
-        input_text += self.template['INSTRUCTION'].format(
-            input=text, round=1, bot_name=self.bot_name)
-        input_text = past_text + input_text
-        ids = self.tokenizer.encode(input_text)
-        ret_past_text = self.tokenizer.decode(ids)
-        ids = torch.tensor(ids).cuda().unsqueeze(0)
-
-        attention_mask = torch.ones_like(ids, dtype=torch.bool)
-
-        mm_inputs = {
-            'pixel_values': input_dict['pixel_values'],
-            'input_ids': ids,
-            'attention_mask': attention_mask,
-            'position_ids': None,
-            'past_key_values': None,
-            'labels': None,
-            'prompt_masks': mask_prompts,
-            'vp_overall_mask': input_dict['vp_overall_mask'],
-        }
+        if image is None and video is None and '<image>' not in past_text:
+            text = text.replace('<image>', "")
+            input_text = ''
+            input_text += self.template['INSTRUCTION'].format(
+                input=text, round=1, bot_name=self.bot_name)
+            input_text = past_text + input_text
+            ids = self.tokenizer.encode(input_text)
+            ids = torch.tensor(ids).cuda().unsqueeze(0)
+
+            attention_mask = torch.ones_like(ids, dtype=torch.bool)
+
+            mm_inputs = {
+                'pixel_values': None,
+                'input_ids': ids,
+                'attention_mask': attention_mask,
+                'position_ids': None,
+                'past_key_values': None,
+                'labels': None,
+                'prompt_masks': None,
+                'vp_overall_mask': None,
+            }
+            ret_masks = []
+        else:
+            input_dict = {}
+            if video is not None:
+                pixel_values = []
+                extra_pixel_values = []
+                ori_image_size = video[0].size
+                for frame_idx, frame_image in enumerate(video):
+                    assert ori_image_size == frame_image.size
+                    g_image = np.array(frame_image)  # for grounding
+                    g_image = self.extra_image_processor.apply_image(g_image)
+                    g_image = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
+                    extra_pixel_values.append(g_image)
+                    if frame_idx < 5:
+                        img = self.transformer(frame_image)
+                        pixel_values.append(img)
+
+                pixel_values = torch.stack(pixel_values, dim=0).to(self.torch_dtype)  # (n_f, 3, h, w)
+                g_pixel_values = torch.stack([
+                    self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
+                ]).to(self.torch_dtype)
+                num_image_tokens = self.patch_token
+                num_frames = len(pixel_values)
+
+                input_dict['vp_overall_mask'] = None
+            else:
+                ori_image_size = image.size
+
+                # prepare grounding images
+                g_image = np.array(image)  # for grounding
+                g_image = self.extra_image_processor.apply_image(g_image)
+                g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous().to(self.torch_dtype)
+                extra_pixel_values = [g_pixel_values]
+                g_pixel_values = torch.stack([
+                    self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
+                ]).to(self.torch_dtype)
+
+                images = dynamic_preprocess(image, self.min_dynamic_patch,
+                                            self.max_dynamic_patch,
+                                            self.image_size, self.use_thumbnail)
+
+                if mask_prompts is not None:
+                    vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
+                    input_dict['vp_overall_mask'] = vp_overall_mask
+                else:
+                    input_dict['vp_overall_mask'] = None
+
+                pixel_values = [self.transformer(image) for image in images]
+                pixel_values = torch.stack(pixel_values).to(self.torch_dtype)
+                num_image_tokens = pixel_values.shape[0] * self.patch_token
+                num_frames = 1
+            input_dict['g_pixel_values'] = g_pixel_values
+            input_dict['pixel_values'] = pixel_values
+
+            if mask_prompts is not None:
+                # reshape mask prompts to feature size
+                mask_prompts = [torch.Tensor(item).to(pixel_values.device) for item in mask_prompts]
+                mask_prompts = [F.interpolate(
+                    item.unsqueeze(0),
+                    size=(int(self.image_size // self.patch_size * self.downsample_ratio),
+                          int(self.image_size // self.patch_size * self.downsample_ratio)),
+                    mode='nearest').squeeze(0) for item in mask_prompts]
+                region_pixels = []
+                for mask_prompt in mask_prompts[0]:
+                    region_pixels.append(mask_prompt.bool().to(torch.int64).sum())
+
+                vp_token_str = '\nThere are {} part regions in the picture: '.format(len(mask_prompts[0]))
+                for i in range(len(mask_prompts[0])):
+                    vp_token_str = vp_token_str + \
+                        f"region{i + 1}" + self.VP_START_TOKEN + \
+                        self.IMG_CONTEXT_TOKEN * region_pixels[i] + \
+                        self.VP_END_TOKEN
+                    if i == len(mask_prompts[0]) - 1:
+                        vp_token_str = vp_token_str + '.\n'
+                    else:
+                        vp_token_str = vp_token_str + ', '
+            else:
+                vp_token_str = ''
+
+            image_token_str = f'{self.IMG_START_TOKEN}' \
+                              f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
+                              f'{self.IMG_END_TOKEN}'
+            image_token_str = image_token_str + '\n'
+            image_token_str = image_token_str * num_frames
+            image_token_str = image_token_str.strip()
+
+            ret_masks = []
+
+            if '<image>' in text or mask_prompts is not None:
+                assert past_text is None or len(past_text) == 0
+                text = text.replace('<image>', image_token_str + vp_token_str)
+            input_text = ''
+            input_text += self.template['INSTRUCTION'].format(
+                input=text, round=1, bot_name=self.bot_name)
+            input_text = past_text + input_text
+            ids = self.tokenizer.encode(input_text)
+            ids = torch.tensor(ids).cuda().unsqueeze(0)
+
+            attention_mask = torch.ones_like(ids, dtype=torch.bool)
+
+            mm_inputs = {
+                'pixel_values': input_dict['pixel_values'],
+                'input_ids': ids,
+                'attention_mask': attention_mask,
+                'position_ids': None,
+                'past_key_values': None,
+                'labels': None,
+                'prompt_masks': mask_prompts,
+                'vp_overall_mask': input_dict['vp_overall_mask'],
+            }
 
         generate_output = self.generate(
             **mm_inputs,
@@ -716,8 +737,10 @@ class Sa2VAChatModel(PreTrainedModel):
         )
         predict = self.tokenizer.decode(
             generate_output.sequences[0], skip_special_tokens=False).strip()
-        ret_past_text = ret_past_text + self.tokenizer.decode(
-            generate_output.sequences[0], skip_special_tokens=False)
+
+        if image is None and video is None and '<image>' not in past_text:
+            return {'prediction': predict, 'prediction_masks': ret_masks, }
+
         # if have seg result, find the seg hidden states
         hidden_states = generate_output.hidden_states
         last_hidden_states = [item[-1][0] for item in hidden_states]
@@ -739,7 +762,8 @@ class Sa2VAChatModel(PreTrainedModel):
             masks = masks.sigmoid() > 0.5
             masks = masks.cpu().numpy()
             ret_masks.append(masks)
-        return {'prediction': predict, 'prediction_masks': ret_masks, "past_text": ret_past_text}
+
+        return {'prediction': predict, 'prediction_masks': ret_masks,}
 
 def get_seg_hidden_states(hidden_states, output_ids, seg_id):
     seg_mask = output_ids == seg_id
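For context, a minimal usage sketch of the text-only branch introduced by this commit (not part of the diff itself). The method name predict_forward, the loading calls, and the placeholder checkpoint id are assumptions; the keyword names (image, video, text, past_text, tokenizer) and the returned 'prediction' / 'prediction_masks' keys are taken from the changed code.

    import torch
    from transformers import AutoModel, AutoTokenizer

    path = "<namespace>/<sa2va-checkpoint>"   # illustrative placeholder
    model = AutoModel.from_pretrained(
        path, torch_dtype=torch.bfloat16, trust_remote_code=True).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

    # Neither image nor video is supplied and past_text carries no '<image>' tag,
    # so the new branch builds a plain text prompt and returns no masks.
    out = model.predict_forward(
        image=None,
        video=None,
        text="Describe the task of referring segmentation.",
        past_text="",
        tokenizer=tokenizer,
    )
    print(out['prediction'])        # decoded answer text
    print(out['prediction_masks'])  # [] for the text-only path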