Fabrice-TIERCELIN committed on
Commit 0ed7616
1 Parent(s): e752a35

Fix launch error

llava/model/language_model/llava_llama.py CHANGED
@@ -1,140 +1,140 @@
-# Copyright 2023 Haotian Liu
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-from torch.nn import CrossEntropyLoss
-
-from transformers import AutoConfig, AutoModelForCausalLM, \
-                         LlamaConfig, LlamaModel, LlamaForCausalLM
-
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
-from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
-
-
-class LlavaConfig(LlamaConfig):
-    model_type = "llava"
-
-
-class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
-    config_class = LlavaConfig
-
-    def __init__(self, config: LlamaConfig):
-        super(LlavaLlamaModel, self).__init__(config)
-
-
-class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
-    config_class = LlavaConfig
-
-    def __init__(self, config):
-        super(LlamaForCausalLM, self).__init__(config)
-        self.model = LlavaLlamaModel(config)
-
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_model(self):
-        return self.model
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        images: Optional[torch.FloatTensor] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict
-        )
-
-        hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model/pipeline parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-    def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-    ):
-        if past_key_values:
-            input_ids = input_ids[:, -1:]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "images": kwargs.get("images", None),
-            }
-        )
-        return model_inputs
-
-AutoConfig.register("llava", LlavaConfig)
-AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from transformers import AutoConfig, AutoModelForCausalLM, \
+                         LlamaConfig, LlamaModel, LlamaForCausalLM
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaConfig(LlamaConfig):
+    model_type = "llava"
+
+
+class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
+    config_class = LlavaConfig
+
+    def __init__(self, config: LlamaConfig):
+        super(LlavaLlamaModel, self).__init__(config)
+
+
+class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
+    config_class = LlavaConfig
+
+    def __init__(self, config):
+        super(LlamaForCausalLM, self).__init__(config)
+        self.model = LlavaLlamaModel(config)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_model(self):
+        return self.model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model/pipeline parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+                "images": kwargs.get("images", None),
+            }
+        )
+        return model_inputs
+
+AutoConfig.register("llava", LlavaConfig, exist_ok=True)
+AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
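
The only functional change is the final registration call (file line 139), which now passes exist_ok=True. Recent transformers releases already register a built-in "llava" model type, so registering the name again without that flag raises a ValueError at import time, which is presumably the launch error this commit fixes. A minimal sketch of the behavior, assuming a transformers version where "llava" is already taken; DemoConfig is a hypothetical stand-in for the custom LlavaConfig defined in the file above:

from transformers import AutoConfig, PretrainedConfig

class DemoConfig(PretrainedConfig):
    # Hypothetical config class; model_type must match the name being registered.
    model_type = "llava"

try:
    # Without exist_ok, re-registering an already-used model type raises ValueError
    # (this only fires when "llava" is already in the config mapping).
    AutoConfig.register("llava", DemoConfig)
except ValueError as err:
    print(f"plain register failed: {err}")

# exist_ok=True skips the duplicate-name check, so the call succeeds even when
# "llava" is already registered, which is what the patched line relies on.
AutoConfig.register("llava", DemoConfig, exist_ok=True)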