Upload 7 files
Browse files- modeling_indictrans.py +11 -5
modeling_indictrans.py
CHANGED
@@ -803,6 +803,12 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
|
|
803 |
# Initialize weights and apply final processing
|
804 |
self.post_init()
|
805 |
|
|
|
|
|
|
|
|
|
|
|
|
|
806 |
def forward(
|
807 |
self,
|
808 |
input_ids: Optional[torch.Tensor] = None,
|
@@ -1196,11 +1202,11 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
|
1196 |
"""
|
1197 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1198 |
|
1199 |
-
if labels is not None:
|
1200 |
-
|
1201 |
-
|
1202 |
-
|
1203 |
-
|
1204 |
|
1205 |
outputs = self.model(
|
1206 |
input_ids,
|
|
|
803 |
# Initialize weights and apply final processing
|
804 |
self.post_init()
|
805 |
|
806 |
+
def get_input_embeddings(self):
|
807 |
+
return self.embed_tokens.word_embeddings
|
808 |
+
|
809 |
+
def set_input_embeddings(self, value):
|
810 |
+
self.embed_tokens.word_embeddings = value
|
811 |
+
|
812 |
def forward(
|
813 |
self,
|
814 |
input_ids: Optional[torch.Tensor] = None,
|
|
|
1202 |
"""
|
1203 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1204 |
|
1205 |
+
# if labels is not None:
|
1206 |
+
# if decoder_input_ids is None:
|
1207 |
+
# decoder_input_ids = shift_tokens_right(
|
1208 |
+
# labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
1209 |
+
# )
|
1210 |
|
1211 |
outputs = self.model(
|
1212 |
input_ids,
|