Text Generation
Transformers
PyTorch
mosaic_gpt
custom_code

Add labels into forward

#1
by i-gao - opened
Files changed (1) hide show
  1. mosaic_gpt.py +22 -3
mosaic_gpt.py CHANGED
@@ -238,6 +238,7 @@ class MosaicGPT(PreTrainedModel):
238
  input_ids: torch.LongTensor,
239
  past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
240
  attention_mask: Optional[torch.ByteTensor] = None,
 
241
  prefix_mask: Optional[torch.ByteTensor] = None,
242
  sequence_id: Optional[torch.LongTensor] = None,
243
  return_dict: Optional[bool] = None,
@@ -370,9 +371,27 @@ class MosaicGPT(PreTrainedModel):
370
  )
371
  logits *= self.logit_scale
372
 
373
- return CausalLMOutputWithPast(logits=logits,
374
- past_key_values=past_key_values,
375
- hidden_states=all_hidden_states)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
 
377
  # Param Initialization, needed for device='meta' fast initialization
378
  def param_init_fn(self, module):
 
238
  input_ids: torch.LongTensor,
239
  past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
240
  attention_mask: Optional[torch.ByteTensor] = None,
241
+ labels: Optional[torch.LongTensor] = None,
242
  prefix_mask: Optional[torch.ByteTensor] = None,
243
  sequence_id: Optional[torch.LongTensor] = None,
244
  return_dict: Optional[bool] = None,
 
371
  )
372
  logits *= self.logit_scale
373
 
374
+ # compute loss from logits
375
+ if labels is not None:
376
+ # Shift so that tokens < n predict n
377
+ shift_logits = logits[..., :-1, :].contiguous()
378
+ shift_labels = labels[..., 1:].contiguous()
379
+ # Flatten the tokens
380
+ loss_fct = nn.CrossEntropyLoss()
381
+ loss = loss_fct(
382
+ shift_logits.view(
383
+ -1, self.transformer.wte.num_embeddings
384
+ ),
385
+ shift_labels.view(-1),
386
+ )
387
+ return CausalLMOutputWithPast(loss=loss, logits=logits,
388
+ past_key_values=past_key_values,
389
+ hidden_states=all_hidden_states)
390
+
391
+ else:
392
+ return CausalLMOutputWithPast(logits=logits,
393
+ past_key_values=past_key_values,
394
+ hidden_states=all_hidden_states)
395
 
396
  # Param Initialization, needed for device='meta' fast initialization
397
  def param_init_fn(self, module):