pranjalchitale commited on
Commit
4000789
·
verified ·
1 Parent(s): a765fec

Fixes tie weights.

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +9 -8
modeling_indictrans.py CHANGED
@@ -1643,7 +1643,7 @@ class IndicTransModel(IndicTransPreTrainedModel):
1643
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
1644
  class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMixin):
1645
  base_model_prefix = "model"
1646
- _tied_weights_keys = None
1647
  _label_smoothing = 0.0
1648
 
1649
  def __init__(self, config: IndicTransConfig):
@@ -1653,19 +1653,20 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMi
1653
  config.decoder_embed_dim, config.decoder_vocab_size, bias=False
1654
  )
1655
 
1656
- if config.share_decoder_input_output_embed:
1657
- self.lm_head.weight = self.model.decoder.embed_tokens.weight
1658
-
1659
  self.post_init()
1660
 
1661
- def tie_weights(self):
1662
- pass
 
1663
 
1664
  def get_encoder(self):
1665
- return self.model.get_encoder()
1666
 
1667
  def get_decoder(self):
1668
- return self.model.get_decoder()
 
 
 
1669
 
1670
  def get_output_embeddings(self):
1671
  return self.lm_head
 
1643
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
1644
  class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMixin):
1645
  base_model_prefix = "model"
1646
+ _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
1647
  _label_smoothing = 0.0
1648
 
1649
  def __init__(self, config: IndicTransConfig):
 
1653
  config.decoder_embed_dim, config.decoder_vocab_size, bias=False
1654
  )
1655
 
 
 
 
1656
  self.post_init()
1657
 
1658
+ def tie_weights(self):
1659
+       if self.config.share_decoder_input_output_embed:
1660
+           self._tie_or_clone_weights(self.decoder.embed_tokens, self.lm_head)
1661
 
1662
  def get_encoder(self):
1663
+ return self.model.encoder
1664
 
1665
  def get_decoder(self):
1666
+ return self.model.decoder
1667
+
1668
+ def get_input_embeddings(self):
1669
+ return self.model.encoder.embed_tokens
1670
 
1671
  def get_output_embeddings(self):
1672
  return self.lm_head