Raghavan commited on
Commit
bb8ed51
1 Parent(s): 6b39eaa

Upload 7 files

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +11 -5
modeling_indictrans.py CHANGED
@@ -803,6 +803,12 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
803
  # Initialize weights and apply final processing
804
  self.post_init()
805
 
 
 
 
 
 
 
806
  def forward(
807
  self,
808
  input_ids: Optional[torch.Tensor] = None,
@@ -1196,11 +1202,11 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
1196
  """
1197
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1198
 
1199
- if labels is not None:
1200
- if decoder_input_ids is None:
1201
- decoder_input_ids = shift_tokens_right(
1202
- labels, self.config.pad_token_id, self.config.decoder_start_token_id
1203
- )
1204
 
1205
  outputs = self.model(
1206
  input_ids,
 
803
  # Initialize weights and apply final processing
804
  self.post_init()
805
 
806
+ def get_input_embeddings(self):
807
+ return self.embed_tokens.word_embeddings
808
+
809
+ def set_input_embeddings(self, value):
810
+ self.embed_tokens.word_embeddings = value
811
+
812
  def forward(
813
  self,
814
  input_ids: Optional[torch.Tensor] = None,
 
1202
  """
1203
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1204
 
1205
+ # if labels is not None:
1206
+ # if decoder_input_ids is None:
1207
+ # decoder_input_ids = shift_tokens_right(
1208
+ # labels, self.config.pad_token_id, self.config.decoder_start_token_id
1209
+ # )
1210
 
1211
  outputs = self.model(
1212
  input_ids,