from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers import GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel

from .configuration_phantom import PhantomConfig
from .modeling_intern_vit import InternVisionModel
from .modeling_internlm2 import InternLM2ForCausalLM

from utils.utils import *

class PhantomForCausalLM(PreTrainedModel):
    config_class = PhantomConfig
    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    _no_split_modules = ['InternVisionModel', 'InternLM2DecoderLayer']

    def __init__(self, config: PhantomConfig):
        super().__init__(config)
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.template = config.template
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio

        self.vision_model = InternVisionModel(config.vision_config)
        self.language_model = InternLM2ForCausalLM(config.llm_config)

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size

        self.vision_proj = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        # prompt rule: chat-template delimiters (<|im_start|>/<|im_end|> format) consumed by make_instruction
        self.prompt_rule = {
                            "system_start": "<|im_start|>system\n",
                            "system_end": "<|im_end|>",
                            "user_start": "<|im_start|>user\n",
                            "user_end": "<|im_end|>",
                            "assistant_start": "<|im_start|>assistant\n",
                            "assistant_end": "<|im_end|>",
                            "test_start": "assistant\n",
                            "test_end": "<|im_end|>",
                            "split": "",
                            }

    def eval_process(
        self,
        inputs,
        tokenizer,
        data,
        device,
    ):
        batched_image = []
        batched_qa_prompt = []
        batched_phantom_position = []
        for _input in inputs:

            # making image prompt
            if 'image' in _input.keys() and _input['image'] is not None:
                process_image = dynamic_preprocess(_input['image'].to(device))
                dynamic_process_image = torch.stack([dynamic_transform(image) for image in process_image]).to(device)
                img_token_number = dynamic_process_image.shape[0] * 256  # 256 <IMG_CONTEXT> tokens per tile (expected to equal self.num_image_token)
                batched_image.append(dynamic_process_image)

            # make question and answer
            question = _input['question']

            # make instruction (qa pair) and label
            qa_prompt = make_instruction(question, data, self.prompt_rule)
            
            # adding image special tokens to question
            if 'image' in _input.keys():
                qa_prompt = qa_prompt.replace('<image>', '<img><IMG_CONTEXT></img>')

                # add bundle image tokens if it has <image> token
                qa_prompt = add_bundle_tokens(qa_prompt, '<IMG_CONTEXT>', img_token_number) 

            # phantom_position: mark the first prompt token
            label = tokenizer(qa_prompt, return_tensors='pt', add_special_tokens=False).input_ids[0].to(device)
            phantom_position = torch.zeros_like(label)
            phantom_position[0] = 1

            # batched processing (phantom_position is flipped here and flipped back after
            # pad_sequence below, which amounts to left padding)
            batched_qa_prompt.append(qa_prompt)
            batched_phantom_position.append(phantom_position.flip(dims=[0]))

        '''For Final Outputs'''
        qa_prompts = tokenizer(batched_qa_prompt, padding='longest', return_tensors="pt", add_special_tokens=False)

        # [1] input_ids
        input_ids = qa_prompts.input_ids.to(device)
  
        # [2] attention_mask
        attention_mask = qa_prompts.attention_mask.to(device)

        # [3] Phantom Position
        batched_phantom_position = torch.nn.utils.rnn.pad_sequence(batched_phantom_position, batch_first=True, padding_value=0).flip(dims=[1]) # padding left

        if len(batched_image):
            return {"input_ids": input_ids, 
                    "attention_mask": attention_mask, 
                    "pixel_values": torch.cat(batched_image, dim=0).to(device),
                    "phantom_position": batched_phantom_position.bool()
                    }
        else:
            return {"input_ids": input_ids, 
                    "attention_mask": attention_mask, 
                    "phantom_position": batched_phantom_position.bool()
                    }

    def extract_feature(self, pixel_values):
        # ViT patch features, dropping the [CLS] token
        vit_embeds = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=False,
            return_dict=True).last_hidden_state
        vit_embeds = vit_embeds[:, 1:, :]

        # reshape the patch sequence into a square grid, pixel-shuffle to trade spatial
        # resolution for channel depth (downsample_ratio), then project to the LLM hidden size
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.vision_proj(vit_embeds)
        return vit_embeds

    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            phantom_position: Optional[torch.BoolTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        
        if pixel_values is not None:
            # encode the image tiles and scatter their embeddings into the
            # <IMG_CONTEXT> placeholder positions of the text embedding sequence
            vit_embeds = self.extract_feature(pixel_values.to(torch.bfloat16))
            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.config.image_token_index)
            assert selected.sum() != 0
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            phantom_position=phantom_position,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            pad_token_id=self.config.eos_token_id,
            eos_token_id=self.config.eos_token_id,
            **generate_kwargs,
        )

        return outputs
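
# --- Usage sketch (illustrative only) -----------------------------------------------
# A minimal sketch of how this class is typically driven end-to-end. The checkpoint id
# "BK-Lee/Phantom-7B", the image path, and the data tag "demo" are assumptions for the
# sketch, not guarantees; eval_process also relies on the helpers from utils.utils
# (dynamic_preprocess, dynamic_transform, make_instruction, add_bundle_tokens) being importable.
if __name__ == "__main__":
    from PIL import Image
    from torchvision.transforms.functional import pil_to_tensor
    from transformers import AutoTokenizer

    model_path = "BK-Lee/Phantom-7B"  # assumed checkpoint id; replace with your own
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = PhantomForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
    ).to(device).eval()

    # one sample: an image tensor plus a question that contains the <image> placeholder
    image = pil_to_tensor(Image.open("example.jpg").convert("RGB"))
    inputs = [{"image": image, "question": "<image>\nDescribe this image."}]

    # tokenize, left-pad, and build pixel_values / phantom_position for the batch
    batch = model.eval_process(inputs, tokenizer, data="demo", device=device)

    with torch.inference_mode():
        output_ids = model.generate(**batch, do_sample=False, max_new_tokens=128)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])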