Raghavan committed on
Commit
8657564
1 Parent(s): f65ef53

Upload 7 files

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +5 -2
modeling_indictrans.py CHANGED
@@ -61,12 +61,15 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
61
 
62
 
63
  def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
 
 
 
 
 
64
  mask = (decoder_input_ids == eos_token_id)
65
  decoder_input_ids[mask] = 1
66
  decoder_attention_mask[mask] = 0
67
 
68
- labels = decoder_input_ids[:, 1:]
69
-
70
  return decoder_input_ids, decoder_attention_mask, labels
71
 
72
 
 
61
 
62
 
63
  def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
64
+ labels = decoder_input_ids[:, 1:]
65
+
66
+ labels_mask = labels == 1
67
+ labels[labels_mask] = -100
68
+
69
  mask = (decoder_input_ids == eos_token_id)
70
  decoder_input_ids[mask] = 1
71
  decoder_attention_mask[mask] = 0
72
 
 
 
73
  return decoder_input_ids, decoder_attention_mask, labels
74
 
75