g-h-chen committed on
Commit 8f83af3
1 Parent(s): 06f7d63

upload modeling_llava_stablelm_1_6b.py

Files changed (1):
    modeling_llava_stablelm_1_6b.py  +239 -0
modeling_llava_stablelm_1_6b.py ADDED
@@ -0,0 +1,239 @@
+ # Copyright 2023 Haotian Liu
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import List, Optional, Tuple, Union
+ import warnings
+
+ import torch
+ import torch.nn as nn
+
+
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, PretrainedConfig
+ # StableLMEpochConfig, StableLMEpochModel, StableLMEpochForCausalLM
+ from transformers.modeling_utils import cached_file, CONFIG_NAME, extract_commit_hash, is_peft_available, find_adapter_config_file, json, os
+ from transformers.models.auto.auto_factory import _BaseAutoModelClass, _get_model_class
+ from transformers.dynamic_module_utils import resolve_trust_remote_code, get_class_from_dynamic_module
+
+
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ import pdb
+
+
+ import sys
+ from .llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+ from .modeling_stablelm_epoch import StableLMEpochForCausalLM, StableLMEpochModel, StableLMEpochConfig
+ from .generation_utils import build_allava_input
+
+
+ ################ stableLM ###############################
+
+ class LlavaStableLM_1_6bConfig(StableLMEpochConfig):
+     model_type = "llava_stablelm_1_6b"
+
+ # class LlavaStableLMModel(LlavaMetaModel, AutoModel):
+ class LlavaStableLMModel(LlavaMetaModel, StableLMEpochModel):
+     config_class = LlavaStableLM_1_6bConfig
+
+     def __init__(self, config: AutoConfig):
+         super(LlavaStableLMModel, self).__init__(config)
+
+
+
+ class LlavaStableLM_1_6bForCausalLM(StableLMEpochForCausalLM, LlavaMetaForCausalLM):
+     config_class = LlavaStableLM_1_6bConfig
+
+
+     def __init__(self, config, init_vision_encoder_from_ckpt=True):
+         config._attn_implementation = "flash_attention_2"
+
+         # bypass StableLMEpochForCausalLM.__init__ so the backbone can be replaced
+         # with the multimodal LlavaStableLMModel below
+         super(StableLMEpochForCausalLM, self).__init__(config)
+
+         self.model = LlavaStableLMModel(config)
+         if hasattr(self.model, '_use_flash_attention_2'):
+             assert self.model._use_flash_attention_2, 'flash attention 2 is not enabled; check the config!'
+         # self.pretraining_tp = config.pretraining_tp
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         if init_vision_encoder_from_ckpt:
+             vision_tower = self.get_vision_tower()
+             print('loading the vision encoder from CLIP first. This should only be used at inference!')
+             vision_tower.load_model()
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+
+     def get_model(self):
+         return self.model
+
+     def get_tokenizer(self):
+         return self.tokenizer
+
+     def get_processor(self):
+         return self.model.vision_tower.image_processor
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+         if inputs_embeds is None:
+             (
+                 input_ids,
+                 position_ids,
+                 attention_mask,
+                 past_key_values,
+                 inputs_embeds,
+                 labels
+             # ) = self.prepare_inputs_labels_for_multimodal(
+             ) = self.prepare_inputs_labels_for_multimodal_new(
+                 input_ids,
+                 position_ids,
+                 attention_mask,
+                 past_key_values,
+                 labels,
+                 images
+             )
+
+         return super().forward(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             labels=labels,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict
+         )
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+         images = kwargs.pop("images", None)
+         _inputs = super().prepare_inputs_for_generation(
+             input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+         )
+         if images is not None:
+             _inputs['images'] = images
+         return _inputs
+
+     @torch.no_grad()
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         images: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         position_ids = kwargs.pop("position_ids", None)
+         attention_mask = kwargs.pop("attention_mask", None)
+         if "inputs_embeds" in kwargs:
+             raise NotImplementedError("`inputs_embeds` is not supported")
+
+         if images is not None:
+             (
+                 inputs,
+                 position_ids,
+                 attention_mask,
+                 _,
+                 inputs_embeds,
+                 _
+             ) = self.prepare_inputs_labels_for_multimodal_new(
+                 inputs,
+                 position_ids,
+                 attention_mask,
+                 None,
+                 None,
+                 images
+             )
+         else:
+             inputs_embeds = self.get_model().embed_tokens(inputs)
+
+         # print(inputs_embeds.shape)
+         return super().generate(
+             position_ids=None,
+             attention_mask=None,
+             inputs_embeds=inputs_embeds,
+             **kwargs
+         )
+
+
+
+     def chat(
+         self,
+         texts: Optional[Union[str, list[list[str]]]],
+         images: Optional[Union[str, list[str]]] = None,
+         history: Optional[list[list[str]]] = None,
+         stream=False,
+         return_history=False,
+         **kwargs
+     ):
+         '''
+         texts: if `str`, generate a response for a single round; if a list, each element is one round of the conversation.
+         images: str (optional), local path to an image, or a list of such paths.
+         '''
+         use_cache = kwargs.pop('use_cache', True)
+
+
+         ############################
+         # merge history
+         ############################
+         input_ids, image_tensors, history = build_allava_input(
+             tokenizer=self.get_tokenizer(),
+             processor=self.get_processor(),
+             texts=texts,
+             images=images,
+             history=history,
+             return_history=return_history,
+             device=self.device
+         )
+
+         ############################
+         # generate response
+         ############################
+         # with torch.autocast(device_type='cuda'):
+         if 'cuda' in str(self.device):
+             device_type = 'cuda'
+         else:
+             device_type = 'cpu'
+
+         with torch.autocast(device_type=device_type, dtype=self.dtype):
+             output_ids = self.generate(
+                 inputs=input_ids,
+                 images=image_tensors,
+                 use_cache=use_cache,
+                 **kwargs)
+
+         answer = self.get_tokenizer().decode(output_ids[0, :], skip_special_tokens=True).strip()
+
+         if return_history:
+             history[-1][-1] = answer
+             return answer, history
+         return answer
+
+
+ AutoConfig.register("llava_stablelm_1_6b", LlavaStableLM_1_6bConfig)
+ # AutoConfig.register("stablelm_epoch", LlavaStableLMConfig)
+ AutoModelForCausalLM.register(LlavaStableLM_1_6bConfig, LlavaStableLM_1_6bForCausalLM)
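
The two register() calls at the end hook the new config and model classes into the transformers Auto classes, so a checkpoint that ships this file can be loaded with trust_remote_code=True and queried through the chat() helper. A minimal usage sketch, assuming a local checkpoint that contains this file; the model path, image path, and generation arguments are placeholders, not values from this commit, and chat() reads self.tokenizer (which this file never sets), so the tokenizer is attached manually here:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/llava_stablelm_1_6b_checkpoint"   # placeholder checkpoint directory

# trust_remote_code=True makes from_pretrained pick up the classes registered in this file
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda")
model.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)  # used by get_tokenizer()/chat()

# single-round query about a local image
answer = model.chat(
    texts="What is shown in this image?",
    images="example.jpg",          # placeholder local path
    max_new_tokens=256,
)
print(answer)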