Raghavan commited on
Commit
512b18c
1 Parent(s): 929b4c7

Upload 7 files

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +1 -1
modeling_indictrans.py CHANGED
@@ -61,7 +61,7 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
61
 
62
 
63
  def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
64
- labels = decoder_input_ids.full_size(decoder_input_ids.size(), -100)
65
  labels[:, :-1] = decoder_input_ids[:, 1:]
66
 
67
  labels_mask = labels == 1
 
61
 
62
 
63
  def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
64
+ labels = torch.full(decoder_input_ids.size(),-100)
65
  labels[:, :-1] = decoder_input_ids[:, 1:]
66
 
67
  labels_mask = labels == 1