update files

- config.json +7 -7
- modeling_lsg_barthez.py → modeling_lsg_mbart.py +58 -60
- pytorch_model.bin +2 -2
config.json
CHANGED
@@ -10,12 +10,12 @@
   ],
   "attention_dropout": 0.1,
   "auto_map": {
-    "AutoConfig": "
-    "AutoModel": "
-    "AutoModelForCausalLM": "
-    "AutoModelForQuestionAnswering": "
-    "AutoModelForSeq2SeqLM": "
-    "AutoModelForSequenceClassification": "
+    "AutoConfig": "modeling_lsg_mbart.LSGMBartConfig",
+    "AutoModel": "modeling_lsg_mbart.LSGMBartModel",
+    "AutoModelForCausalLM": "modeling_lsg_mbart.LSGMBartForCausalLM",
+    "AutoModelForQuestionAnswering": "modeling_lsg_mbart.LSGMBartForQuestionAnswering",
+    "AutoModelForSeq2SeqLM": "modeling_lsg_mbart.LSGMBartForConditionalGeneration",
+    "AutoModelForSequenceClassification": "modeling_lsg_mbart.LSGMBartForSequenceClassification"
   },
   "base_model_prefix": "lsg",
   "block_size": 128,
@@ -71,7 +71,7 @@
   "static_position_embeddings": false,
   "tokenizer_class": "BarthezTokenizer",
   "torch_dtype": "float32",
-  "transformers_version": "4.
+  "transformers_version": "4.20.1",
   "use_cache": true,
   "vocab_size": 50002
 }
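Note on the updated auto_map: these entries are what let the transformers Auto classes resolve the custom LSG classes shipped in modeling_lsg_mbart.py when the checkpoint is loaded with trust_remote_code=True. A minimal loading sketch follows; the repo id is a placeholder, not something stated in this commit:

from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

repo_id = "<namespace>/<lsg-mbart-checkpoint>"  # hypothetical repo id, replace with the actual one

# trust_remote_code=True downloads modeling_lsg_mbart.py from the repo and uses the
# classes listed in auto_map instead of the stock MBart implementations.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)  # resolves to BarthezTokenizer per the config
model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, trust_remote_code=True)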
modeling_lsg_barthez.py → modeling_lsg_mbart.py
RENAMED
@@ -1,21 +1,21 @@
 from logging import warn
 import torch
-from transformers.models.
-from transformers.models.
+from transformers.models.mbart.modeling_mbart import *
+from transformers.models.mbart.modeling_mbart import _expand_mask
 import torch.nn as nn
 import sys
 
 AUTO_MAP = {
-    "AutoModel": "
-    "AutoModelForCausalLM": "
-    "AutoModelForQuestionAnswering": "
-    "AutoModelForSequenceClassification": "
-    "AutoModelForSeq2SeqLM": "
+    "AutoModel": "modeling_lsg_mbart.LSGMBartModel",
+    "AutoModelForCausalLM": "modeling_lsg_mbart.LSGMBartForCausalLM",
+    "AutoModelForQuestionAnswering": "modeling_lsg_mbart.LSGMBartForQuestionAnswering",
+    "AutoModelForSequenceClassification": "modeling_lsg_mbart.LSGMBartForSequenceClassification",
+    "AutoModelForSeq2SeqLM": "modeling_lsg_mbart.LSGMBartForConditionalGeneration"
 }
 
-class LSGMBartConfig(
+class LSGMBartConfig(MBartConfig):
     """
-    This class overrides :class:`~transformers.
+    This class overrides :class:`~transformers.RobertaConfig`. Please check the superclass for the appropriate
     documentation alongside usage examples.
     """
 
@@ -41,7 +41,7 @@ class LSGMBartConfig(BartConfig):
     ):
         """Constructs LSGConfig."""
         super().__init__(**kwargs)
-
+
         self.adaptive = adaptive
         self.auto_map = AUTO_MAP
         self.base_model_prefix = base_model_prefix
@@ -81,7 +81,7 @@ class LSGMBartConfig(BartConfig):
         assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
         assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
 
-
+
 class BaseSelfAttention(nn.Module):
 
     def __init__(
@@ -265,7 +265,7 @@ class LSGAttentionProduct(nn.Module):
         s = (size - step) // 2
 
         # Pad before block reshaping
-        if is_attn_mask:
+        if is_attn_mask:
             pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
@@ -295,7 +295,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value = torch.finfo(hidden_states.dtype).min
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -376,7 +376,7 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
 
         if config.sparsity_type == "lsh":
             self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
-
+
     def get_sparse_tokens_with_norm(self, keys, values, mask):
 
         if self.sparsity_factor == 1:
@@ -490,7 +490,6 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
         values /= mask + 1e-8
 
         mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
-
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
 
     def lsh_round(self, keys, values, mask, output_size):
@@ -619,7 +618,7 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
         return x.reshape(n, h, -1, chunk_size, d)
 
 
-class LSGMBartEncoderLayer(
+class LSGMBartEncoderLayer(MBartEncoderLayer):
 
     def __init__(self, config):
 
@@ -632,14 +631,14 @@ class LSGMBartEncoderLayer(BartEncoderLayer):
         )
 
 
-class LSGMBartDecoderLayer(
+class LSGMBartDecoderLayer(MBartDecoderLayer):
 
     def __init__(self, config):
 
         super().__init__(config)
-
 
-
+
+class LSGMBartClassificationHead(MBartClassificationHead):
     """Head for sentence-level classification tasks."""
 
     def __init__(
@@ -653,31 +652,23 @@ class LSGMBartClassificationHead(BartClassificationHead):
         super().__init__(input_dim, inner_dim, num_classes, pooler_dropout)
 
 
-class LSGMBartPretrainedModel(
+class LSGMBartPretrainedModel(MBartPreTrainedModel):
 
     config_class = LSGMBartConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
 
     def _set_gradient_checkpointing(self, module, value=False):
-
-        if isinstance(module, (BartDecoder, BartEncoder, LSGMBartDecoder, LSGMBartEncoder)):
+        if isinstance(module, (MBartDecoder, MBartEncoder, LSGMBartDecoder, LSGMBartEncoder)):
             module.gradient_checkpointing = value
 
 
-class
-
-    def __init_subclass__(self):
-        warnings.warn(
-            "The class `PretrainedBartModel` has been depreciated, please use `LSGMBartPretrainedModel` instead.",
-            FutureWarning,
-        )
-
-
-class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
+class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
     """
     Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
-
+    [`MBartEncoderLayer`].
     Args:
-        config:
+        config: MBartConfig
         embed_tokens (nn.Embedding): output embedding
     """
 
@@ -697,12 +688,13 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
         else:
             self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
 
-        self.embed_positions =
+        self.embed_positions = MBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
         self.layers = nn.ModuleList([LSGMBartEncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(embed_dim)
+        self.layer_norm = nn.LayerNorm(config.d_model)
 
         #
         assert hasattr(config, "num_global_tokens")
@@ -740,8 +732,8 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
         if attention_mask is None:
             attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
         if self.mask_first_token:
-            attention_mask[:,0] = 0
-
+            attention_mask[:, 0] = 0
+
         b = self.block_size * 2
         pad = t % self.block_size
 
@@ -880,6 +872,8 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
             if output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)
 
+        hidden_states = self.layer_norm(hidden_states)
+
         if output_hidden_states:
             encoder_states = encoder_states + (hidden_states,)
 
@@ -890,9 +884,9 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
         )
 
 
-class LSGMBartDecoder(LSGMBartPretrainedModel,
+class LSGMBartDecoder(LSGMBartPretrainedModel, MBartDecoder):
     """
-    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`
+    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
     Args:
         config: BartConfig
         embed_tokens (nn.Embedding): output embedding
@@ -914,20 +908,21 @@ class LSGMBartDecoder(LSGMBartPretrainedModel, BartDecoder):
         else:
             self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
 
-        self.embed_positions =
+        self.embed_positions = MBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )
         self.layers = nn.ModuleList([LSGMBartDecoderLayer(config) for _ in range(config.decoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(config.d_model)
-
+        self.layer_norm = nn.LayerNorm(config.d_model)
+
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
         self.post_init()
 
 
-class LSGMBartModel(LSGMBartPretrainedModel,
+class LSGMBartModel(LSGMBartPretrainedModel, MBartModel):
 
     def __init__(self, config):
 
@@ -964,13 +959,6 @@ class LSGMBartModel(LSGMBartPretrainedModel, BartModel):
         return_dict=None,
     ):
 
-        # different to other models, Bart automatically creates decoder_input_ids from
-        # input_ids if no decoder_input_ids are provided
-        if decoder_input_ids is None and decoder_inputs_embeds is None:
-            decoder_input_ids = shift_tokens_right(
-                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
-            )
-
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -978,6 +966,11 @@ class LSGMBartModel(LSGMBartPretrainedModel, BartModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        # different to other models, MBart automatically creates decoder_input_ids from
+        # input_ids if no decoder_input_ids are provided
+        if decoder_input_ids is None:
+            decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
+
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 input_ids=input_ids,
@@ -999,7 +992,7 @@ class LSGMBartModel(LSGMBartPretrainedModel, BartModel):
         # Pad mask for global tokens
         if self.pass_global_tokens_to_decoder and attention_mask is not None:
             attention_mask = torch.nn.functional.pad(attention_mask, pad=(self.num_global_tokens, 0), value=1)
-
+
         # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
@@ -1031,10 +1024,15 @@ class LSGMBartModel(LSGMBartPretrainedModel, BartModel):
         )
 
 
-class LSGMBartForConditionalGeneration(LSGMBartPretrainedModel,
+class LSGMBartForConditionalGeneration(LSGMBartPretrainedModel, MBartForConditionalGeneration):
 
     base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [
+    _keys_to_ignore_on_load_missing = [
+        r"final_logits_bias",
+        r"encoder.version",
+        r"decoder.version",
+        r"lm_head.weight",
+    ]
 
     def __init__(self, config):
 
@@ -1047,9 +1045,9 @@ class LSGMBartForConditionalGeneration(LSGMBartPretrainedModel, BartForCondition
         self.post_init()
 
 
-class LSGMBartForSequenceClassification(LSGMBartPretrainedModel,
+class LSGMBartForSequenceClassification(LSGMBartPretrainedModel, MBartForSequenceClassification):
 
-    def __init__(self, config
+    def __init__(self, config, **kwargs):
 
         LSGMBartPretrainedModel.__init__(self, config, **kwargs)
         self.model = LSGMBartModel(config)
@@ -1063,9 +1061,9 @@ class LSGMBartForSequenceClassification(LSGMBartPretrainedModel, BartForSequence
         self.model._init_weights(self.classification_head.out_proj)
 
 
-class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel,
+class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel, MBartForQuestionAnswering):
 
-    def __init__(self, config
+    def __init__(self, config):
 
         LSGMBartPretrainedModel.__init__(self, config)
 
@@ -1077,14 +1075,14 @@ class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel, BartForQuestionAnswe
 
         self.model._init_weights(self.qa_outputs)
 
-
+
 class LSGMBartDecoderWrapper(LSGMBartPretrainedModel):
     """
     This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
     used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
     """
 
-    def __init__(self, config
+    def __init__(self, config):
         super().__init__(config)
         self.decoder = LSGMBartDecoder(config)
 
@@ -1092,9 +1090,9 @@ class LSGMBartDecoderWrapper(LSGMBartPretrainedModel):
         return self.decoder(*args, **kwargs)
 
 
-class LSGMBartForCausalLM(LSGMBartPretrainedModel,
+class LSGMBartForCausalLM(LSGMBartPretrainedModel, MBartForCausalLM):
 
-    def __init__(self, config
+    def __init__(self, config):
 
         config = copy.deepcopy(config)
         config.is_decoder = True
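The substance of the rename is the switch from the BART base classes to the MBart ones: MBart applies a final layer_norm after the encoder and decoder stacks (hence the added self.layer_norm modules and the new hidden_states = self.layer_norm(hidden_states) call), uses MBartLearnedPositionalEmbedding, and its shift_tokens_right takes only (input_ids, pad_token_id), rotating the last non-pad token (the language-id token) to position 0 instead of prepending decoder_start_token_id. A standalone sketch of that target-shifting convention, written here for illustration rather than copied from transformers:

import torch

def mbart_style_shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Shift every sequence one position to the right and wrap the last non-pad token
    # (the language-id token in MBart inputs) around to position 0, rather than
    # inserting a fixed decoder_start_token_id as BART does.
    prev_output_tokens = input_ids.clone()
    index_of_last = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    decoder_start_tokens = prev_output_tokens.gather(1, index_of_last).squeeze(-1)
    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
    prev_output_tokens[:, 0] = decoder_start_tokens
    return prev_output_tokens

# Example: pad_token_id = 1, language-id token 250004 sits after </s> (id 2)
ids = torch.tensor([[9, 8, 7, 2, 250004, 1, 1]])
print(mbart_style_shift_tokens_right(ids, pad_token_id=1))
# tensor([[250004, 9, 8, 7, 2, 250004, 1]])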
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:6baea65e9a49a6c7d5e99b6622bc7da3af0da5e22c378953bd63eff2cb86390e
+size 577617519
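pytorch_model.bin is stored through Git LFS, so the diff only touches the pointer file (oid and size). A downloaded copy of the weights can be checked against the new pointer's oid; a small sketch, assuming the file has been fetched locally as pytorch_model.bin:

import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    # Stream the file in chunks so a multi-hundred-MB checkpoint never has to fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

expected = "6baea65e9a49a6c7d5e99b6622bc7da3af0da5e22c378953bd63eff2cb86390e"
print(sha256_of("pytorch_model.bin") == expected)  # True if the download matches the pointer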