ccdv committed on
Commit ce82caa
1 Parent(s): 3406624

update files

config.json CHANGED
@@ -10,12 +10,12 @@
   ],
   "attention_dropout": 0.1,
   "auto_map": {
-    "AutoConfig": "modeling_lsg_barthez.LSGMBartConfig",
-    "AutoModel": "modeling_lsg_barthez.LSGMBartModel",
-    "AutoModelForCausalLM": "modeling_lsg_barthez.LSGMBartForCausalLM",
-    "AutoModelForQuestionAnswering": "modeling_lsg_barthez.LSGMBartForQuestionAnswering",
-    "AutoModelForSeq2SeqLM": "modeling_lsg_barthez.LSGMBartForConditionalGeneration",
-    "AutoModelForSequenceClassification": "modeling_lsg_barthez.LSGMBartForSequenceClassification"
+    "AutoConfig": "modeling_lsg_mbart.LSGMBartConfig",
+    "AutoModel": "modeling_lsg_mbart.LSGMBartModel",
+    "AutoModelForCausalLM": "modeling_lsg_mbart.LSGMBartForCausalLM",
+    "AutoModelForQuestionAnswering": "modeling_lsg_mbart.LSGMBartForQuestionAnswering",
+    "AutoModelForSeq2SeqLM": "modeling_lsg_mbart.LSGMBartForConditionalGeneration",
+    "AutoModelForSequenceClassification": "modeling_lsg_mbart.LSGMBartForSequenceClassification"
   },
   "base_model_prefix": "lsg",
   "block_size": 128,
@@ -71,7 +71,7 @@
   "static_position_embeddings": false,
   "tokenizer_class": "BarthezTokenizer",
   "torch_dtype": "float32",
-  "transformers_version": "4.19.2",
+  "transformers_version": "4.20.1",
   "use_cache": true,
   "vocab_size": 50002
 }
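With this change the auto_map entries resolve to the renamed modeling_lsg_mbart.py module, which transformers only imports when remote code is explicitly trusted. A minimal loading sketch, assuming an illustrative checkpoint id (the repo name is not part of this commit):

from transformers import AutoConfig, AutoModelForSeq2SeqLM

repo_id = "ccdv/lsg-barthez-4096"  # hypothetical repo id, for illustration only

# trust_remote_code=True makes transformers import the classes listed in "auto_map"
# (e.g. LSGMBartForConditionalGeneration from modeling_lsg_mbart.py in the repo).
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, trust_remote_code=True)
print(type(config).__name__, type(model).__name__)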
modeling_lsg_barthez.py → modeling_lsg_mbart.py RENAMED
@@ -1,21 +1,21 @@
 from logging import warn
 import torch
-from transformers.models.bart.modeling_bart import *
-from transformers.models.bart.modeling_bart import _expand_mask
+from transformers.models.mbart.modeling_mbart import *
+from transformers.models.mbart.modeling_mbart import _expand_mask
 import torch.nn as nn
 import sys
 
 AUTO_MAP = {
-    "AutoModel": "modeling_lsg_barthez.LSGMBartModel",
-    "AutoModelForCausalLM": "modeling_lsg_barthez.LSGMBartForCausalLM",
-    "AutoModelForQuestionAnswering": "modeling_lsg_barthez.LSGMBartForQuestionAnswering",
-    "AutoModelForSequenceClassification": "modeling_lsg_barthez.LSGMBartForSequenceClassification",
-    "AutoModelForSeq2SeqLM": "modeling_lsg_barthez.LSGMBartForConditionalGeneration"
+    "AutoModel": "modeling_lsg_mbart.LSGMBartModel",
+    "AutoModelForCausalLM": "modeling_lsg_mbart.LSGMBartForCausalLM",
+    "AutoModelForQuestionAnswering": "modeling_lsg_mbart.LSGMBartForQuestionAnswering",
+    "AutoModelForSequenceClassification": "modeling_lsg_mbart.LSGMBartForSequenceClassification",
+    "AutoModelForSeq2SeqLM": "modeling_lsg_mbart.LSGMBartForConditionalGeneration"
 }
 
-class LSGMBartConfig(BartConfig):
+class LSGMBartConfig(MBartConfig):
     """
-    This class overrides :class:`~transformers.BartConfig`. Please check the superclass for the appropriate
+    This class overrides :class:`~transformers.RobertaConfig`. Please check the superclass for the appropriate
     documentation alongside usage examples.
     """
 
@@ -41,7 +41,7 @@ class LSGMBartConfig(BartConfig):
     ):
         """Constructs LSGConfig."""
         super().__init__(**kwargs)
-
+
         self.adaptive = adaptive
         self.auto_map = AUTO_MAP
         self.base_model_prefix = base_model_prefix
@@ -81,7 +81,7 @@ class LSGMBartConfig(BartConfig):
         assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
         assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
 
-
+
 class BaseSelfAttention(nn.Module):
 
     def __init__(
@@ -265,7 +265,7 @@ class LSGAttentionProduct(nn.Module):
         s = (size - step) // 2
 
         # Pad before block reshaping
-        if is_attn_mask:
+        if is_attn_mask:
             pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
@@ -295,7 +295,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value = torch.finfo(hidden_states.dtype).min
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -376,7 +376,7 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
 
         if config.sparsity_type == "lsh":
             self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
-
+
    def get_sparse_tokens_with_norm(self, keys, values, mask):
 
        if self.sparsity_factor == 1:
@@ -490,7 +490,6 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
        values /= mask + 1e-8
 
        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
-
        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
 
    def lsh_round(self, keys, values, mask, output_size):
@@ -619,7 +618,7 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
        return x.reshape(n, h, -1, chunk_size, d)
 
 
-class LSGMBartEncoderLayer(BartEncoderLayer):
+class LSGMBartEncoderLayer(MBartEncoderLayer):
 
    def __init__(self, config):
 
@@ -632,14 +631,14 @@ class LSGMBartEncoderLayer(BartEncoderLayer):
        )
 
 
-class LSGMBartDecoderLayer(BartDecoderLayer):
+class LSGMBartDecoderLayer(MBartDecoderLayer):
 
    def __init__(self, config):
 
        super().__init__(config)
-
 
-class LSGMBartClassificationHead(BartClassificationHead):
+
+class LSGMBartClassificationHead(MBartClassificationHead):
    """Head for sentence-level classification tasks."""
 
    def __init__(
@@ -653,31 +652,23 @@ class LSGMBartClassificationHead(BartClassificationHead):
        super().__init__(input_dim, inner_dim, num_classes, pooler_dropout)
 
 
-class LSGMBartPretrainedModel(BartPretrainedModel):
+class LSGMBartPretrainedModel(MBartPreTrainedModel):
 
    config_class = LSGMBartConfig
+   base_model_prefix = "model"
+   supports_gradient_checkpointing = True
 
    def _set_gradient_checkpointing(self, module, value=False):
-       print(isinstance(module, (BartDecoder, BartEncoder, LSGMBartDecoder, LSGMBartEncoder)))
-       if isinstance(module, (BartDecoder, BartEncoder, LSGMBartDecoder, LSGMBartEncoder)):
+       if isinstance(module, (MBartDecoder, MBartEncoder, LSGMBartDecoder, LSGMBartEncoder)):
            module.gradient_checkpointing = value
 
 
-class PretrainedLSGMBartModel(LSGMBartPretrainedModel):
-
-   def __init_subclass__(self):
-       warnings.warn(
-           "The class `PretrainedBartModel` has been depreciated, please use `LSGMBartPretrainedModel` instead.",
-           FutureWarning,
-       )
-
-
-class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
+class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
-   :class:`BartEncoderLayer`.
+   [`MBartEncoderLayer`].
    Args:
-       config: BartConfig
+       config: MBartConfig
        embed_tokens (nn.Embedding): output embedding
    """
 
@@ -697,12 +688,13 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
 
-       self.embed_positions = BartLearnedPositionalEmbedding(
+       self.embed_positions = MBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        self.layers = nn.ModuleList([LSGMBartEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(embed_dim)
+       self.layer_norm = nn.LayerNorm(config.d_model)
 
        #
        assert hasattr(config, "num_global_tokens")
@@ -740,8 +732,8 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
        if attention_mask is None:
            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
        if self.mask_first_token:
-           attention_mask[:,0] = 0
-
+           attention_mask[:, 0] = 0
+
        b = self.block_size * 2
        pad = t % self.block_size
 
@@ -880,6 +872,8 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
 
+       hidden_states = self.layer_norm(hidden_states)
+
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)
 
@@ -890,9 +884,9 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
        )
 
 
-class LSGMBartDecoder(LSGMBartPretrainedModel, BartDecoder):
+class LSGMBartDecoder(LSGMBartPretrainedModel, MBartDecoder):
    """
-   Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGMBartDecoderLayer`
+   Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
@@ -914,20 +908,21 @@ class LSGMBartDecoder(LSGMBartPretrainedModel, BartDecoder):
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
 
-       self.embed_positions = BartLearnedPositionalEmbedding(
+       self.embed_positions = MBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )
        self.layers = nn.ModuleList([LSGMBartDecoderLayer(config) for _ in range(config.decoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(config.d_model)
-
+       self.layer_norm = nn.LayerNorm(config.d_model)
+
        self.gradient_checkpointing = False
 
        # Initialize weights and apply final processing
        self.post_init()
 
 
-class LSGMBartModel(LSGMBartPretrainedModel, BartModel):
+class LSGMBartModel(LSGMBartPretrainedModel, MBartModel):
 
    def __init__(self, config):
 
@@ -964,13 +959,6 @@ class LSGMBartModel(LSGMBartPretrainedModel, BartModel):
        return_dict=None,
    ):
 
-       # different to other models, Bart automatically creates decoder_input_ids from
-       # input_ids if no decoder_input_ids are provided
-       if decoder_input_ids is None and decoder_inputs_embeds is None:
-           decoder_input_ids = shift_tokens_right(
-               input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
-           )
-
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -978,6 +966,11 @@ class LSGMBartModel(LSGMBartPretrainedModel, BartModel):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+       # different to other models, MBart automatically creates decoder_input_ids from
+       # input_ids if no decoder_input_ids are provided
+       if decoder_input_ids is None:
+           decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
+
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
@@ -999,7 +992,7 @@ class LSGMBartModel(LSGMBartPretrainedModel, BartModel):
        # Pad mask for global tokens
        if self.pass_global_tokens_to_decoder and attention_mask is not None:
            attention_mask = torch.nn.functional.pad(attention_mask, pad=(self.num_global_tokens, 0), value=1)
-
+
        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
@@ -1031,10 +1024,15 @@ class LSGMBartModel(LSGMBartPretrainedModel, BartModel):
        )
 
 
-class LSGMBartForConditionalGeneration(LSGMBartPretrainedModel, BartForConditionalGeneration):
+class LSGMBartForConditionalGeneration(LSGMBartPretrainedModel, MBartForConditionalGeneration):
 
    base_model_prefix = "model"
-   _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
+   _keys_to_ignore_on_load_missing = [
+       r"final_logits_bias",
+       r"encoder.version",
+       r"decoder.version",
+       r"lm_head.weight",
+   ]
 
    def __init__(self, config):
 
@@ -1047,9 +1045,9 @@ class LSGMBartForConditionalGeneration(LSGMBartPretrainedModel, BartForCondition
        self.post_init()
 
 
-class LSGMBartForSequenceClassification(LSGMBartPretrainedModel, BartForSequenceClassification):
+class LSGMBartForSequenceClassification(LSGMBartPretrainedModel, MBartForSequenceClassification):
 
-   def __init__(self, config: LSGMBartConfig, **kwargs):
+   def __init__(self, config, **kwargs):
 
        LSGMBartPretrainedModel.__init__(self, config, **kwargs)
        self.model = LSGMBartModel(config)
@@ -1063,9 +1061,9 @@ class LSGMBartForSequenceClassification(LSGMBartPretrainedModel, BartForSequence
        self.model._init_weights(self.classification_head.out_proj)
 
 
-class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel, BartForQuestionAnswering):
+class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel, MBartForQuestionAnswering):
 
-   def __init__(self, config: LSGMBartConfig):
+   def __init__(self, config):
 
        LSGMBartPretrainedModel.__init__(self, config)
 
@@ -1077,14 +1075,14 @@ class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel, BartForQuestionAnswe
 
        self.model._init_weights(self.qa_outputs)
 
-
+
 class LSGMBartDecoderWrapper(LSGMBartPretrainedModel):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
    """
 
-   def __init__(self, config: LSGMBartConfig):
+   def __init__(self, config):
        super().__init__(config)
        self.decoder = LSGMBartDecoder(config)
 
@@ -1092,9 +1090,9 @@ class LSGMBartDecoderWrapper(LSGMBartPretrainedModel):
        return self.decoder(*args, **kwargs)
 
 
-class LSGMBartForCausalLM(LSGMBartPretrainedModel, BartForCausalLM):
+class LSGMBartForCausalLM(LSGMBartPretrainedModel, MBartForCausalLM):
 
-   def __init__(self, config: LSGMBartConfig):
+   def __init__(self, config):
 
        config = copy.deepcopy(config)
        config.is_decoder = True
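Beyond the module rename, the functional changes above are: every base class now comes from MBart (MBartConfig, MBartEncoder, MBartDecoder, and so on) instead of Bart, the encoder and decoder gain a final layer_norm applied after the last layer (matching the MBart architecture), and decoder_input_ids are now derived with MBart's two-argument shift_tokens_right, which rotates the trailing language-id token to position 0 rather than prepending decoder_start_token_id. A small sketch of that helper's behaviour, with made-up token ids, assuming the stock transformers implementation that the wildcard import pulls in:

import torch
from transformers.models.mbart.modeling_mbart import shift_tokens_right

pad_token_id = 1
input_ids = torch.tensor([[62, 17, 35, 2, 250004]])  # "tok tok tok </s> <lang_id>", ids invented for the demo
shifted = shift_tokens_right(input_ids, pad_token_id)
print(shifted)  # first row becomes [250004, 62, 17, 35, 2]: the language id is wrapped to the front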
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d023eaea53d03dadf200fcac226bdc5a3e7349a2c20d718439d9c9cfb47e614e
-size 577604023
+oid sha256:6baea65e9a49a6c7d5e99b6622bc7da3af0da5e22c378953bd63eff2cb86390e
+size 577617519
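The weight file itself only changes through its LFS pointer; per the git-lfs spec referenced in the pointer, the oid is the SHA-256 of the stored binary, so a downloaded copy can be checked against the new pointer. A minimal sketch, assuming pytorch_model.bin has been fetched into the current directory:

import hashlib

expected = "6baea65e9a49a6c7d5e99b6622bc7da3af0da5e22c378953bd63eff2cb86390e"  # oid from the new pointer

sha = hashlib.sha256()
with open("pytorch_model.bin", "rb") as f:  # assumed local path
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks instead of loading ~578 MB at once
        sha.update(chunk)
print(sha.hexdigest() == expected)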