Merge cekal/mpt-7b-peft-compatible

#42
Opened by muelletm
Files changed (1)
  1. modeling_mpt.py +68 -8
modeling_mpt.py CHANGED
@@ -23,12 +23,19 @@ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
  base_model_prefix = 'model'
 
 
 
 
 
 
26
 
27
  class MPTModel(MPTPreTrainedModel):
28
 
29
  def __init__(self, config: MPTConfig):
30
  config._validate_config()
31
  super().__init__(config)
 
32
  self.attn_impl = config.attn_config['attn_impl']
33
  self.prefix_lm = config.attn_config['prefix_lm']
34
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
@@ -127,11 +134,40 @@ class MPTModel(MPTPreTrainedModel):
127
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
128
  return attn_bias
129
 
130
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
131
  return_dict = return_dict if return_dict is not None else self.config.return_dict
132
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  if attention_mask is not None:
134
  attention_mask = attention_mask.bool()
 
 
 
 
 
 
 
 
 
 
135
  if prefix_mask is not None:
136
  prefix_mask = prefix_mask.bool()
137
  if not return_dict:
@@ -147,9 +183,8 @@ class MPTModel(MPTPreTrainedModel):
147
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
148
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
149
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
150
- S = input_ids.size(1)
151
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
152
- tok_emb = self.wte(input_ids)
153
  if self.alibi:
154
  x = tok_emb
155
  else:
@@ -161,7 +196,7 @@ class MPTModel(MPTPreTrainedModel):
161
  if S + past_position > self.config.max_seq_len:
162
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
163
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
164
- if attention_mask is not None:
165
  pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
166
  pos_emb = self.wpe(pos)
167
  x = tok_emb + pos_emb
@@ -180,7 +215,27 @@ class MPTModel(MPTPreTrainedModel):
180
  assert all_hidden_states is not None
181
  all_hidden_states = all_hidden_states + (x,)
182
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
183
- (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  if past_key_values is not None:
185
  past_key_values[b_idx] = past_key_value
186
  x = self.norm_f(x)
@@ -231,11 +286,16 @@ class MPTForCausalLM(MPTPreTrainedModel):
231
  def get_decoder(self):
232
  return self.transformer
233
 
234
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
235
  return_dict = return_dict if return_dict is not None else self.config.return_dict
236
  use_cache = use_cache if use_cache is not None else self.config.use_cache
237
- outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
238
- logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
 
 
 
 
 
239
  if self.logit_scale is not None:
240
  if self.logit_scale == 0:
241
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
 
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
  base_model_prefix = 'model'
26
+ _no_split_modules = ["MPTBlock"]
27
+ supports_gradient_checkpointing = True
28
+
29
+ def _set_gradient_checkpointing(self, module, value=False):
30
+ if isinstance(module, MPTModel):
31
+ module.gradient_checkpointing = value
32
 
33
  class MPTModel(MPTPreTrainedModel):
34
 
35
  def __init__(self, config: MPTConfig):
36
  config._validate_config()
37
  super().__init__(config)
38
+ self.gradient_checkpointing = False
39
  self.attn_impl = config.attn_config['attn_impl']
40
  self.prefix_lm = config.attn_config['prefix_lm']
41
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
 
134
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
135
  return attn_bias
136
 
137
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor] = None):
138
  return_dict = return_dict if return_dict is not None else self.config.return_dict
139
  use_cache = use_cache if use_cache is not None else self.config.use_cache
140
+ if self.gradient_checkpointing and self.training:
141
+ if use_cache:
142
+ use_cache = False
143
+ if input_ids is not None and inputs_embeds is not None:
144
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
145
+ elif input_ids is not None:
146
+ batch_size, seq_length = input_ids.shape
147
+ elif inputs_embeds is not None:
148
+ batch_size, seq_length, _ = inputs_embeds.shape
149
+ else:
150
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
151
+
152
+ seq_length_with_past = seq_length
153
+ past_key_values_length = 0
154
+
155
+ if past_key_values is not None:
156
+ past_key_values_length = past_key_values[0][0].shape[2]
157
+ seq_length_with_past = seq_length_with_past + past_key_values_length
158
+
159
  if attention_mask is not None:
160
  attention_mask = attention_mask.bool()
161
+ else:
162
+ attention_mask = torch.ones(
163
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
164
+ )
165
+
166
+ if inputs_embeds is None:
167
+ tok_emb = self.wte(input_ids)
168
+ else:
169
+ tok_emb = inputs_embeds
170
+
171
  if prefix_mask is not None:
172
  prefix_mask = prefix_mask.bool()
173
  if not return_dict:
 
183
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
184
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
185
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
186
+ S = seq_length
187
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
 
188
  if self.alibi:
189
  x = tok_emb
190
  else:
 
196
  if S + past_position > self.config.max_seq_len:
197
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
198
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
199
+ if attention_mask is not None and not self.training:
200
  pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
201
  pos_emb = self.wpe(pos)
202
  x = tok_emb + pos_emb
 
215
  assert all_hidden_states is not None
216
  all_hidden_states = all_hidden_states + (x,)
217
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
218
+
219
+ if self.gradient_checkpointing and self.training:
220
+
221
+ def create_custom_forward(module):
222
+ def custom_forward(*inputs):
223
+ # None for past_key_value
224
+ return module(*inputs)
225
+
226
+ return custom_forward
227
+
228
+ (x, past_key_value) = torch.utils.checkpoint.checkpoint(
229
+ create_custom_forward(block),
230
+ x,
231
+ past_key_value,
232
+ attn_bias,
233
+ attention_mask,
234
+ self.is_causal,
235
+ )
236
+ else:
237
+ (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
238
+
239
  if past_key_values is not None:
240
  past_key_values[b_idx] = past_key_value
241
  x = self.norm_f(x)
 
286
  def get_decoder(self):
287
  return self.transformer
288
 
289
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor] = None):
290
  return_dict = return_dict if return_dict is not None else self.config.return_dict
291
  use_cache = use_cache if use_cache is not None else self.config.use_cache
292
+ outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, inputs_embeds=inputs_embeds)
293
+
294
+ last_hidden_state = outputs.last_hidden_state
295
+ if self.model_parallel:
296
+ last_hidden_state = last_hidden_state.to(self.transformer.wte.weight.device)
297
+ logits = F.linear(last_hidden_state, self.transformer.wte.weight)
298
+
299
  if self.logit_scale is not None:
300
  if self.logit_scale == 0:
301
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')