Raghavan commited on
Commit
d03c1e2
1 Parent(s): dd5aa99

Upload 7 files

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +9 -6
modeling_indictrans.py CHANGED
@@ -61,20 +61,23 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
61
 
62
 
63
  def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
64
- labels = torch.full(decoder_input_ids.size(),-100)
65
- labels[:, :-1] = decoder_input_ids[:, 1:]
 
 
 
66
 
67
  labels_mask = labels == 1
68
  labels[labels_mask] = -100
69
 
70
- mask = (decoder_input_ids == eos_token_id)
71
- decoder_input_ids[mask] = 1
72
- decoder_attention_mask[mask] = 0
73
-
74
 
75
  return decoder_input_ids, decoder_attention_mask, labels
76
 
77
 
 
78
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
79
  def _make_causal_mask(
80
  input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
 
61
 
62
 
63
  def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
64
+ new_decoder_input_ids = decoder_input_ids.clone().detach()
65
+ new_decoder_attention_mask = decoder_attention_mask.clone().detach()
66
+
67
+ labels = torch.full(new_decoder_input_ids.size(),-100)
68
+ labels[:, :-1] = new_decoder_input_ids[:, 1:]
69
 
70
  labels_mask = labels == 1
71
  labels[labels_mask] = -100
72
 
73
+ mask = (new_decoder_input_ids == eos_token_id)
74
+ new_decoder_input_ids[mask] = 1
75
+ new_decoder_attention_mask[mask] = 0
 
76
 
77
  return decoder_input_ids, decoder_attention_mask, labels
78
 
79
 
80
+
81
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
82
  def _make_causal_mask(
83
  input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0