Fabrice-TIERCELIN commited on
Commit
ccde449
·
verified ·
1 Parent(s): eedc21b

Upload 2 files

Browse files
llava/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, \
23
+ LlamaConfig, LlamaModel, LlamaForCausalLM
24
+
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaConfig(LlamaConfig):
31
+ model_type = "llava"
32
+
33
+
34
+ class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
35
+ config_class = LlavaConfig
36
+
37
+ def __init__(self, config: LlamaConfig):
38
+ super(LlavaLlamaModel, self).__init__(config)
39
+
40
+
41
+ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
42
+ config_class = LlavaConfig
43
+
44
+ def __init__(self, config):
45
+ super(LlamaForCausalLM, self).__init__(config)
46
+ self.model = LlavaLlamaModel(config)
47
+
48
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49
+
50
+ # Initialize weights and apply final processing
51
+ self.post_init()
52
+
53
+ def get_model(self):
54
+ return self.model
55
+
56
+ def forward(
57
+ self,
58
+ input_ids: torch.LongTensor = None,
59
+ attention_mask: Optional[torch.Tensor] = None,
60
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
61
+ inputs_embeds: Optional[torch.FloatTensor] = None,
62
+ labels: Optional[torch.LongTensor] = None,
63
+ use_cache: Optional[bool] = None,
64
+ output_attentions: Optional[bool] = None,
65
+ output_hidden_states: Optional[bool] = None,
66
+ images: Optional[torch.FloatTensor] = None,
67
+ return_dict: Optional[bool] = None,
68
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
69
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
70
+ output_hidden_states = (
71
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
72
+ )
73
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
74
+
75
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
76
+
77
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
78
+ outputs = self.model(
79
+ input_ids=input_ids,
80
+ attention_mask=attention_mask,
81
+ past_key_values=past_key_values,
82
+ inputs_embeds=inputs_embeds,
83
+ use_cache=use_cache,
84
+ output_attentions=output_attentions,
85
+ output_hidden_states=output_hidden_states,
86
+ return_dict=return_dict
87
+ )
88
+
89
+ hidden_states = outputs[0]
90
+ logits = self.lm_head(hidden_states)
91
+
92
+ loss = None
93
+ if labels is not None:
94
+ # Shift so that tokens < n predict n
95
+ shift_logits = logits[..., :-1, :].contiguous()
96
+ shift_labels = labels[..., 1:].contiguous()
97
+ # Flatten the tokens
98
+ loss_fct = CrossEntropyLoss()
99
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
100
+ shift_labels = shift_labels.view(-1)
101
+ # Enable model/pipeline parallelism
102
+ shift_labels = shift_labels.to(shift_logits.device)
103
+ loss = loss_fct(shift_logits, shift_labels)
104
+
105
+ if not return_dict:
106
+ output = (logits,) + outputs[1:]
107
+ return (loss,) + output if loss is not None else output
108
+
109
+ return CausalLMOutputWithPast(
110
+ loss=loss,
111
+ logits=logits,
112
+ past_key_values=outputs.past_key_values,
113
+ hidden_states=outputs.hidden_states,
114
+ attentions=outputs.attentions,
115
+ )
116
+
117
+ def prepare_inputs_for_generation(
118
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
119
+ ):
120
+ if past_key_values:
121
+ input_ids = input_ids[:, -1:]
122
+
123
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
124
+ if inputs_embeds is not None and past_key_values is None:
125
+ model_inputs = {"inputs_embeds": inputs_embeds}
126
+ else:
127
+ model_inputs = {"input_ids": input_ids}
128
+
129
+ model_inputs.update(
130
+ {
131
+ "past_key_values": past_key_values,
132
+ "use_cache": kwargs.get("use_cache"),
133
+ "attention_mask": attention_mask,
134
+ "images": kwargs.get("images", None),
135
+ }
136
+ )
137
+ return model_inputs
138
+
139
+ AutoConfig.register("llava", LlavaConfig)
140
+ AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
llava/model/language_model/llava_mpt.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple
17
+ import warnings
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ import math
22
+
23
+ from transformers import AutoConfig, AutoModelForCausalLM
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+
26
+ from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
27
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaMPTConfig(MPTConfig):
31
+ model_type = "llava_mpt"
32
+
33
+
34
+ class LlavaMPTModel(LlavaMetaModel, MPTModel):
35
+ config_class = LlavaMPTConfig
36
+
37
+ def __init__(self, config: MPTConfig):
38
+ config.hidden_size = config.d_model
39
+ super(LlavaMPTModel, self).__init__(config)
40
+
41
+ def embed_tokens(self, x):
42
+ return self.wte(x)
43
+
44
+
45
+ class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM):
46
+ config_class = LlavaMPTConfig
47
+ supports_gradient_checkpointing = True
48
+
49
+ def __init__(self, config):
50
+ super(MPTForCausalLM, self).__init__(config)
51
+
52
+ if not config.tie_word_embeddings:
53
+ raise ValueError('MPTForCausalLM only supports tied word embeddings')
54
+ self.transformer = LlavaMPTModel(config)
55
+ self.logit_scale = None
56
+ if config.logit_scale is not None:
57
+ logit_scale = config.logit_scale
58
+ if isinstance(logit_scale, str):
59
+ if logit_scale == 'inv_sqrt_d_model':
60
+ logit_scale = 1 / math.sqrt(config.d_model)
61
+ else:
62
+ raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
63
+ self.logit_scale = logit_scale
64
+
65
+ def get_model(self):
66
+ return self.transformer
67
+
68
+ def _set_gradient_checkpointing(self, module, value=False):
69
+ if isinstance(module, LlavaMPTModel):
70
+ module.gradient_checkpointing = value
71
+
72
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):
73
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
74
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
75
+
76
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
77
+ outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
78
+ # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338
79
+ logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
80
+ if self.logit_scale is not None:
81
+ if self.logit_scale == 0:
82
+ warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
83
+ logits *= self.logit_scale
84
+ loss = None
85
+ if labels is not None:
86
+ labels = torch.roll(labels, shifts=-1)
87
+ labels[:, -1] = -100
88
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
89
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
90
+
91
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
92
+ if inputs_embeds is not None:
93
+ raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
94
+ attention_mask = kwargs['attention_mask'].bool()
95
+ if attention_mask[:, -1].sum() != attention_mask.shape[0]:
96
+ raise NotImplementedError('MPT does not support generation with right padding.')
97
+ if self.transformer.attn_uses_sequence_id and self.training:
98
+ sequence_id = torch.zeros_like(input_ids[:1])
99
+ else:
100
+ sequence_id = None
101
+ if past_key_values is not None:
102
+ input_ids = input_ids[:, -1].unsqueeze(-1)
103
+ if self.transformer.prefix_lm:
104
+ prefix_mask = torch.ones_like(attention_mask)
105
+ if kwargs.get('use_cache') == False:
106
+ raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
107
+ else:
108
+ prefix_mask = None
109
+ return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)}
110
+
111
+
112
+ AutoConfig.register("llava_mpt", LlavaMPTConfig)
113
+ AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM)